Spaces: Running on Zero

AbstractPhil committed
Commit · 6228595
Parent(s): 4ab6146

Browse files:
- app.py: +478 -637
- install.sh: +97 -0
- requirements.txt: +19 -4
- setup.py: +31 -0
app.py
CHANGED
@@ -3,9 +3,54 @@ Mirel Harmony Inference - HF Space (Gradio)
|
|
3 |
ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B
|
4 |
Proper LoRA adapter loading and conversion for MX compatibility
|
5 |
Single file: app.py
|
6 |
"""
|
7 |
from __future__ import annotations
|
8 |
-
import os, gc, json,
|
9 |
from dataclasses import dataclass
|
10 |
from typing import List, Dict, Optional, Any, Union
|
11 |
from datetime import datetime
|
@@ -14,7 +59,7 @@ import spaces # required for ZeroGPU
|
|
14 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
15 |
import numpy as np
|
16 |
|
17 |
-
# Suppress warnings
|
18 |
warnings.filterwarnings("ignore", message=".*microscaling.*")
|
19 |
warnings.filterwarnings("ignore", message=".*mx.*")
|
20 |
|
@@ -32,770 +77,566 @@ try:
|
|
32 |
ReasoningEffort
|
33 |
)
|
34 |
HARMONY_AVAILABLE = True
|
|
|
35 |
except ImportError:
|
36 |
-
print("
|
37 |
HARMONY_AVAILABLE = False
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
|
44 |
-
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")
|
45 |
-
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")
|
46 |
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
|
47 |
-
|
48 |
-
|
49 |
-
ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "
|
|
|
50 |
|
51 |
-
#
|
52 |
IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
|
53 |
-
USE_MX_FORMAT =
|
54 |
|
55 |
-
# Harmony channels for
|
56 |
REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
|
57 |
|
58 |
-
# HF
|
59 |
-
HF_TOKEN
|
60 |
os.getenv("HF_TOKEN")
|
61 |
or os.getenv("HUGGING_FACE_HUB_TOKEN")
|
62 |
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
63 |
or os.getenv("HF_ACCESS_TOKEN")
|
64 |
)
|
65 |
|
66 |
-
def _hf_login()
|
67 |
-
"""Login to
|
68 |
if HF_TOKEN:
|
69 |
try:
|
70 |
from huggingface_hub import login, whoami
|
71 |
login(token=HF_TOKEN, add_to_git_credential=True)
|
72 |
try:
|
73 |
-
|
74 |
-
print(f"
|
75 |
-
except
|
76 |
-
print("
|
77 |
except Exception as e:
|
78 |
-
print(f"
|
79 |
else:
|
80 |
-
print("
|
81 |
|
82 |
-
# Login before loading
|
83 |
_hf_login()
|
84 |
|
|
|
85 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
86 |
|
87 |
-
#
|
88 |
-
if HARMONY_AVAILABLE:
|
89 |
-
harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
90 |
-
else:
|
91 |
-
harmony_encoding = None
|
92 |
-
|
93 |
-
# Stop tokens per Harmony spec: <|return|> (200002), <|call|> (200012)
|
94 |
-
HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions() if HARMONY_AVAILABLE else []
|
95 |
-
|
96 |
-
# Tokenizer is lightweight; load once
|
97 |
try:
|
98 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
|
99 |
-
print(f"
|
100 |
except Exception as e:
|
101 |
-
print(f"
|
102 |
raise
|
103 |
|
104 |
-
#
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
_HAS_PEFT = False
|
112 |
-
print("[Warning] PEFT not available. Install with: pip install peft")
|
113 |
|
114 |
-
#
|
115 |
-
try:
|
116 |
-
import msamp
|
117 |
-
_HAS_MSAMP = True
|
118 |
-
print("[Info] Microsoft AMP (msamp) available for MX format support")
|
119 |
-
except ImportError:
|
120 |
-
_HAS_MSAMP = False
|
121 |
-
print("[Info] msamp not available - using fallback MX handling")
|
122 |
-
|
123 |
-
# -----------------------
|
124 |
-
# MX Format Conversion
|
125 |
-
# -----------------------
|
126 |
-
def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
127 |
-
"""
|
128 |
-
Convert fp32 LoRA weights to be compatible with MX format base model.
|
129 |
-
MX models expect specific dtype handling.
|
130 |
-
"""
|
131 |
-
converted = {}
|
132 |
-
|
133 |
-
for key, tensor in lora_state_dict.items():
|
134 |
-
if tensor is None:
|
135 |
-
converted[key] = tensor
|
136 |
-
continue
|
137 |
-
|
138 |
-
# LoRA weights (lora_A, lora_B) need special handling
|
139 |
-
if 'lora_' in key:
|
140 |
-
# For MX compatibility, we keep weights in fp32 but ensure proper scaling
|
141 |
-
# MX format internally handles quantization, we just need clean fp32 inputs
|
142 |
-
if tensor.dtype != torch.float32:
|
143 |
-
tensor = tensor.to(torch.float32)
|
144 |
-
|
145 |
-
# Ensure weights are in reasonable range for MX quantization
|
146 |
-
# MX format works best with weights in [-1, 1] range
|
147 |
-
if 'lora_A' in key:
|
148 |
-
# Input projection - initialize with small values
|
149 |
-
std = 1.0 / torch.sqrt(torch.tensor(tensor.shape[1], dtype=torch.float32))
|
150 |
-
if tensor.std() > std * 10: # If weights are too large
|
151 |
-
print(f"[MX Convert] Scaling down {key} from std={tensor.std():.4f} to {std:.4f}")
|
152 |
-
tensor = tensor * (std / tensor.std())
|
153 |
-
elif 'lora_B' in key:
|
154 |
-
# Output projection - should be near zero initially
|
155 |
-
if tensor.abs().max() > 0.1:
|
156 |
-
print(f"[MX Convert] Scaling down {key} max={tensor.abs().max():.4f}")
|
157 |
-
tensor = tensor * 0.01
|
158 |
-
|
159 |
-
converted[key] = tensor
|
160 |
-
else:
|
161 |
-
# Non-LoRA weights (like embeddings) stay as-is
|
162 |
-
converted[key] = tensor
|
163 |
-
|
164 |
-
return converted
|
165 |
|
166 |
-
def
|
167 |
-
"""
|
168 |
-
|
169 |
-
|
170 |
-
"""
|
171 |
-
if not _HAS_PEFT:
|
172 |
-
raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
185 |
|
186 |
-
|
|
|
187 |
|
188 |
-
#
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
#
|
197 |
-
try:
|
198 |
-
adapter_weights_path = hf_hub_download(
|
199 |
-
repo_id=adapter_path,
|
200 |
-
filename="adapter_model.safetensors",
|
201 |
-
subfolder=subfolder,
|
202 |
-
token=HF_TOKEN
|
203 |
-
)
|
204 |
-
adapter_weights = load_file(adapter_weights_path)
|
205 |
-
print(f"[LoRA] Loaded safetensors weights from {subfolder}")
|
206 |
-
except Exception:
|
207 |
-
# Try .bin format
|
208 |
-
adapter_weights_path = hf_hub_download(
|
209 |
-
repo_id=adapter_path,
|
210 |
-
filename="adapter_model.bin",
|
211 |
-
subfolder=subfolder,
|
212 |
-
token=HF_TOKEN
|
213 |
-
)
|
214 |
-
adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
|
215 |
-
print(f"[LoRA] Loaded bin weights from {subfolder}")
|
216 |
else:
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
print("[LoRA] Loaded local safetensors weights")
|
224 |
-
elif osp.exists(local_bin):
|
225 |
-
adapter_weights = torch.load(local_bin, map_location="cpu")
|
226 |
-
print("[LoRA] Loaded local bin weights")
|
227 |
-
else:
|
228 |
-
# Try downloading from HF Hub
|
229 |
-
try:
|
230 |
-
adapter_weights_path = hf_hub_download(
|
231 |
-
repo_id=adapter_path,
|
232 |
-
filename="adapter_model.safetensors",
|
233 |
-
token=HF_TOKEN
|
234 |
-
)
|
235 |
-
adapter_weights = load_file(adapter_weights_path)
|
236 |
-
print("[LoRA] Downloaded safetensors weights from Hub")
|
237 |
-
except Exception:
|
238 |
-
adapter_weights_path = hf_hub_download(
|
239 |
-
repo_id=adapter_path,
|
240 |
-
filename="adapter_model.bin",
|
241 |
-
token=HF_TOKEN
|
242 |
-
)
|
243 |
-
adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
|
244 |
-
print("[LoRA] Downloaded bin weights from Hub")
|
245 |
-
|
246 |
-
except Exception as e:
|
247 |
-
raise FileNotFoundError(f"Could not load adapter weights: {e}")
|
248 |
-
|
249 |
-
# Convert weights for MX compatibility
|
250 |
-
print("[LoRA] Converting fp32 weights for MX format compatibility...")
|
251 |
-
adapter_weights = convert_fp32_lora_to_mx_compatible(adapter_weights)
|
252 |
|
253 |
-
#
|
254 |
-
|
255 |
|
256 |
-
#
|
257 |
-
|
258 |
-
|
259 |
-
model
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
|
265 |
-
#
|
266 |
-
model.
|
|
|
|
|
267 |
|
268 |
-
print("[LoRA] Successfully attached LoRA adapter with MX compatibility")
|
269 |
return model
|
270 |
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
"""Build kwargs for model loading with MX format support."""
|
276 |
-
kw: Dict[str, Any] = dict(
|
277 |
-
device_map=device_map,
|
278 |
-
trust_remote_code=True,
|
279 |
-
low_cpu_mem_usage=True,
|
280 |
-
token=HF_TOKEN,
|
281 |
-
)
|
282 |
-
|
283 |
-
if IS_GPT_OSS and USE_MX_FORMAT:
|
284 |
-
# GPT-OSS models use MX format
|
285 |
-
# Don't specify torch_dtype - let the model use its native MX format
|
286 |
-
print("[Model] Using MX format for GPT-OSS model")
|
287 |
-
kw.update({
|
288 |
-
"attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
|
289 |
-
# MX models handle their own dtype internally
|
290 |
-
# Don't force a dtype here
|
291 |
-
})
|
292 |
-
else:
|
293 |
-
# Non-MX models
|
294 |
-
kw.update({
|
295 |
-
"torch_dtype": torch.float16, # Use fp16 for non-MX models
|
296 |
-
"attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
|
297 |
-
})
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
print(f"
|
304 |
|
305 |
-
#
|
306 |
-
|
307 |
|
308 |
-
#
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower()
|
313 |
-
)
|
314 |
|
315 |
-
|
316 |
-
|
|
|
|
|
317 |
|
318 |
-
#
|
319 |
-
|
320 |
-
model = AutoModelForCausalLM.from_pretrained(
|
321 |
-
MODEL_ID,
|
322 |
-
config=config,
|
323 |
-
trust_remote_code=True,
|
324 |
-
device_map=device_map,
|
325 |
-
low_cpu_mem_usage=True,
|
326 |
-
token=HF_TOKEN,
|
327 |
-
# Let the model handle its own dtype
|
328 |
-
attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
|
329 |
-
)
|
330 |
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
-
|
336 |
-
|
337 |
-
|
|
|
338 |
|
339 |
-
#
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
# Standard PEFT loading for non-MX models
|
347 |
-
if not _HAS_PEFT:
|
348 |
-
raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
|
349 |
-
print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
|
350 |
-
peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
|
351 |
-
if ADAPTER_SUBFOLDER:
|
352 |
-
peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
|
353 |
-
print(f"[Model] Using subfolder: {ADAPTER_SUBFOLDER}")
|
354 |
-
model = PeftModel.from_pretrained(
|
355 |
-
model,
|
356 |
-
ADAPTER_ID,
|
357 |
-
**peft_kwargs
|
358 |
-
)
|
359 |
-
|
360 |
-
print("[Model] Successfully loaded with LoRA adapter")
|
361 |
-
|
362 |
-
# Optionally merge adapter for better performance
|
363 |
-
merge_adapter = os.getenv("MERGE_ADAPTER", "0") == "1"
|
364 |
-
if merge_adapter and hasattr(model, 'merge_and_unload'):
|
365 |
-
print("[Model] Merging adapter into base model...")
|
366 |
-
model = model.merge_and_unload()
|
367 |
-
print("[Model] Adapter merged successfully")
|
368 |
-
|
369 |
-
except Exception as e:
|
370 |
-
print(f"[Error] Failed to load adapter: {e}")
|
371 |
-
print("[Warning] Continuing with base model only")
|
372 |
|
373 |
-
|
|
|
374 |
|
375 |
-
#
|
376 |
-
|
377 |
-
|
378 |
-
|
|
|
|
|
379 |
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
|
387 |
-
"""Build a Harmony-formatted prompt."""
|
388 |
-
if HARMONY_AVAILABLE and harmony_encoding is not None:
|
389 |
-
effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
|
390 |
-
effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)
|
391 |
-
|
392 |
-
system_content = (
|
393 |
-
SystemContent.new()
|
394 |
-
.with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
|
395 |
-
.with_reasoning_effort(effort)
|
396 |
-
.with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
|
397 |
-
.with_knowledge_cutoff("2024-06")
|
398 |
-
.with_required_channels(REQUIRED_CHANNELS)
|
399 |
)
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
elif role == "assistant":
|
416 |
-
harmony_messages.append(
|
417 |
-
Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
|
418 |
-
)
|
419 |
-
|
420 |
-
convo = Conversation.from_messages(harmony_messages)
|
421 |
-
return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
|
422 |
-
|
423 |
-
# Fallback: tokenizer chat template
|
424 |
-
if not messages or messages[0].get("role") != "system":
|
425 |
-
messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
|
426 |
-
return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
427 |
|
428 |
def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
|
429 |
-
"""Parse response tokens
|
430 |
-
if not HARMONY_AVAILABLE:
|
431 |
text = tokenizer.decode(tokens, skip_special_tokens=False)
|
432 |
-
return {"final":
|
433 |
-
|
434 |
-
parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
|
435 |
-
|
436 |
-
channels = {}
|
437 |
-
for msg in parsed_messages:
|
438 |
-
channel = msg.channel if hasattr(msg, 'channel') else "final"
|
439 |
-
if channel not in channels:
|
440 |
-
channels[channel] = ""
|
441 |
-
channels[channel] += "".join([getattr(part, "text", str(part)) for part in (msg.content if isinstance(msg.content, list) else [msg.content])])
|
442 |
-
|
443 |
-
if "final" not in channels:
|
444 |
-
channels["final"] = " ".join(channels.values())
|
445 |
|
446 |
-
return channels
|
447 |
-
|
448 |
-
def extract_final_channel_fallback(text: str) -> str:
|
449 |
-
"""Extract the <final> channel from decoded Harmony text."""
|
450 |
try:
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
|
|
479 |
return text.strip()
|
480 |
|
481 |
-
#
|
482 |
-
# Rose guidance
|
483 |
-
# -----------------------
|
484 |
-
def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
|
485 |
-
"""Create vocab bias from {token: weight}."""
|
486 |
-
vocab_size = len(tokenizer)
|
487 |
-
bias = torch.zeros(vocab_size, dtype=torch.float32)
|
488 |
-
for tok, w in mapping.items():
|
489 |
-
if tok is None:
|
490 |
-
continue
|
491 |
-
tid = tokenizer.convert_tokens_to_ids(tok)
|
492 |
-
if isinstance(tid, list):
|
493 |
-
for t in tid:
|
494 |
-
if isinstance(t, int) and t >= 0:
|
495 |
-
bias[t] += float(w) / max(1, len(tid))
|
496 |
-
elif isinstance(tid, int) and tid >= 0:
|
497 |
-
bias[tid] += float(w)
|
498 |
-
return bias
|
499 |
-
|
500 |
-
class RoseGuidedLogits(torch.nn.Module):
|
501 |
-
def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
|
502 |
-
super().__init__()
|
503 |
-
self.bias_vec = bias_vec
|
504 |
-
self.alpha = float(alpha)
|
505 |
-
|
506 |
-
def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
507 |
-
return scores + self.alpha * self.bias_vec.to(scores.device)
|
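
A minimal sketch, under the assumption that the truncated generate call further down wired the removed Rose helpers in roughly this way: the vocabulary bias becomes a logits processor handed to model.generate. The token names and weights are placeholders, and model, tokenizer, and input_ids are assumed from the surrounding code.

from transformers import LogitsProcessorList

bias = build_bias_from_tokens(tokenizer, {"rose": 2.0, "thorn": -0.5})  # placeholder token:weight map
processors = LogitsProcessorList([RoseGuidedLogits(bias, alpha=1.5)])
out = model.generate(input_ids, logits_processor=processors, max_new_tokens=64)
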
508 |
|
509 |
-
# -----------------------
|
510 |
-
# Generation
|
511 |
-
# -----------------------
|
512 |
@spaces.GPU(duration=120)
|
513 |
-
def
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
|
|
|
|
|
|
|
|
520 |
try:
|
|
|
521 |
if seed is not None:
|
522 |
torch.manual_seed(int(seed))
|
523 |
-
|
524 |
-
# Load model with MX support
|
525 |
-
model = _load_model_on("auto")
|
526 |
|
527 |
-
#
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
|
|
|
|
|
|
534 |
# Prepare inputs
|
535 |
device = next(model.parameters()).device
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
prompt_len = input_ids.shape[1]
|
541 |
else:
|
542 |
-
|
543 |
-
inputs =
|
544 |
-
|
545 |
-
|
546 |
-
|
|
|
547 |
|
548 |
# Generate
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
)
|
565 |
|
566 |
# Extract generated tokens
|
567 |
-
|
568 |
-
gen_ids = out_list[prompt_len:]
|
569 |
|
570 |
# Truncate at stop tokens
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
break
|
576 |
|
577 |
# Parse response
|
578 |
-
|
579 |
-
try:
|
580 |
-
channels = parse_harmony_response(gen_ids)
|
581 |
-
except Exception:
|
582 |
-
decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
|
583 |
-
channels = {
|
584 |
-
"final": extract_final_channel_fallback(decoded),
|
585 |
-
"raw": decoded
|
586 |
-
}
|
587 |
-
else:
|
588 |
-
decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
|
589 |
-
channels = {
|
590 |
-
"final": extract_final_channel_fallback(decoded),
|
591 |
-
"raw": decoded
|
592 |
-
}
|
593 |
|
594 |
return channels
|
595 |
-
|
596 |
except Exception as e:
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
finally:
|
602 |
# Cleanup
|
603 |
-
|
604 |
del model
|
605 |
-
except:
|
606 |
-
pass
|
607 |
gc.collect()
|
608 |
if torch.cuda.is_available():
|
609 |
torch.cuda.empty_cache()
|
610 |
|
611 |
-
#
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
|
|
622 |
try:
|
623 |
-
# Build
|
624 |
-
messages = [{"role": "system", "content": system_prompt or
|
625 |
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
|
635 |
-
|
|
|
636 |
|
637 |
# Create prompt
|
638 |
-
|
639 |
-
|
640 |
-
else:
|
641 |
-
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
642 |
-
|
643 |
-
# Build Rose map
|
644 |
-
rose_map: Optional[Dict[str, float]] = None
|
645 |
-
if rose_enable:
|
646 |
-
rose_map = {}
|
647 |
-
tok_str = (rose_tokens or "").strip()
|
648 |
-
if tok_str:
|
649 |
-
for p in [p.strip() for p in tok_str.split(",") if p.strip()]:
|
650 |
-
if ":" in p:
|
651 |
-
k, v = p.split(":", 1)
|
652 |
-
try:
|
653 |
-
rose_map[k.strip()] = float(v)
|
654 |
-
except:
|
655 |
-
pass
|
656 |
-
if rose_json:
|
657 |
-
try:
|
658 |
-
j = json.loads(rose_json)
|
659 |
-
if isinstance(j, dict):
|
660 |
-
for k, v in j.items():
|
661 |
-
try:
|
662 |
-
rose_map[str(k)] = float(v)
|
663 |
-
except:
|
664 |
-
pass
|
665 |
-
except:
|
666 |
-
pass
|
667 |
-
if not rose_map:
|
668 |
-
rose_map = None
|
669 |
-
|
670 |
# Generate
|
671 |
-
channels =
|
672 |
prompt,
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
"no_repeat_ngram_size": 6,
|
681 |
-
},
|
682 |
-
rose_map,
|
683 |
-
float(rose_alpha),
|
684 |
-
float(rose_score) if rose_score is not None else None,
|
685 |
-
int(seed) if seed is not None else None,
|
686 |
)
|
687 |
|
688 |
# Format response
|
689 |
-
if show_thinking:
|
690 |
response = "## Chain of Thought:\n\n"
|
691 |
for channel, content in channels.items():
|
692 |
if channel != "final" and content:
|
693 |
-
response += f"### {channel.capitalize()}
|
694 |
-
response += f"### Final Response:\n{channels.get('final', 'No
|
695 |
-
return response
|
696 |
else:
|
697 |
-
|
698 |
-
|
699 |
-
except Exception as e:
|
700 |
-
import traceback
|
701 |
-
return f"[Error] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
|
702 |
-
|
703 |
-
# -----------------------
|
704 |
-
# UI
|
705 |
-
# -----------------------
|
706 |
-
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
707 |
-
gr.Markdown(
|
708 |
-
f"""
|
709 |
-
# Mirel - Harmony Chain-of-Thought Inference
|
710 |
|
711 |
-
|
712 |
-
**Adapter**: {ADAPTER_ID or 'None'}
|
713 |
-
**Status**: {'✅ Harmony Available' if HARMONY_AVAILABLE else '⚠️ Harmony Not Installed'}
|
714 |
|
715 |
-
|
716 |
-
""
|
717 |
-
)
|
718 |
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
725 |
|
726 |
-
|
|
727 |
with gr.Row():
|
728 |
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
729 |
-
top_p = gr.Slider(0.
|
730 |
-
top_k = gr.Slider(0, 200, value=
|
|
|
731 |
with gr.Row():
|
732 |
-
|
733 |
-
|
734 |
seed = gr.Number(value=None, label="Seed (optional)", precision=0)
|
|
|
735 |
with gr.Row():
|
|
|
|
|
736 |
reasoning_effort = gr.Radio(
|
737 |
-
|
738 |
value="high",
|
739 |
-
label="Reasoning
|
740 |
-
info="How much thinking the model should do"
|
741 |
-
)
|
742 |
-
show_thinking = gr.Checkbox(
|
743 |
-
value=False,
|
744 |
-
label="Show thinking channels",
|
745 |
-
info="Display all internal reasoning channels"
|
746 |
)
|
747 |
|
748 |
-
with gr.Accordion("Rose Guidance (Optional)", open=False):
|
749 |
-
gr.Markdown("Fine-tune generation with token biases")
|
750 |
-
with gr.Row():
|
751 |
-
rose_enable = gr.Checkbox(value=False, label="Enable Rose bias")
|
752 |
-
rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="Alpha (strength)")
|
753 |
-
rose_score = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Score multiplier")
|
754 |
-
rose_tokens = gr.Textbox(
|
755 |
-
label="Token:weight pairs",
|
756 |
-
placeholder="example:1.5, test:-0.5",
|
757 |
-
value=""
|
758 |
-
)
|
759 |
-
rose_json = gr.Textbox(
|
760 |
-
label="JSON weights",
|
761 |
-
placeholder='{"token": 1.0, "another": -0.5}',
|
762 |
-
value=""
|
763 |
-
)
|
764 |
-
|
765 |
# Chat interface
|
766 |
chat = gr.ChatInterface(
|
767 |
-
fn=
|
768 |
-
type="messages",
|
769 |
additional_inputs=[
|
770 |
-
system_prompt,
|
771 |
-
|
772 |
-
|
|
773 |
],
|
774 |
-
title=
|
775 |
-
description="Chain-of-thought model with MX format support",
|
776 |
examples=[
|
777 |
["Hello! Can you introduce yourself?"],
|
778 |
-
["What
|
779 |
-
["Explain quantum computing
|
780 |
-
["
|
781 |
],
|
782 |
cache_examples=False,
|
783 |
)
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
if __name__ == "__main__":
|
797 |
-
|
798 |
-
|
|
|
799 |
server_port=7860,
|
800 |
share=False
|
801 |
)
|
|
|
3 |
ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B
|
4 |
Proper LoRA adapter loading and conversion for MX compatibility
|
5 |
Single file: app.py
|
6 |
+
|
7 |
+
Requirements:
|
8 |
+
huggingface_hub>=0.34.0
|
9 |
+
transformers>=4.55.0
|
10 |
+
accelerate>=0.33.0
|
11 |
+
peft>=0.11.0
|
12 |
+
torch>=2.4.0
|
13 |
+
bitsandbytes>=0.43.1
|
14 |
+
openai-harmony
|
15 |
+
gradio>=5.42.0
|
16 |
+
triton>=3.4.0
|
17 |
+
git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
|
18 |
"""
|
19 |
+
|
20 |
+
# ===== SETUP: Ensure triton_kernels is installed for MX format =====
|
21 |
+
import subprocess
|
22 |
+
import sys
|
23 |
+
|
24 |
+
def ensure_triton_kernels():
|
25 |
+
"""Ensure triton_kernels is installed for MX format support on H200."""
|
26 |
+
try:
|
27 |
+
import triton_kernels
|
28 |
+
print("✓ triton_kernels already installed - MX format supported")
|
29 |
+
return True
|
30 |
+
except ImportError:
|
31 |
+
print("Installing triton_kernels for MX format support...")
|
32 |
+
try:
|
33 |
+
subprocess.check_call([
|
34 |
+
sys.executable, "-m", "pip", "install",
|
35 |
+
"git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
|
36 |
+
])
|
37 |
+
print("✓ triton_kernels installed successfully")
|
38 |
+
# Force reimport
|
39 |
+
import importlib
|
40 |
+
import site
|
41 |
+
importlib.reload(site)
|
42 |
+
return True
|
43 |
+
except subprocess.CalledProcessError as e:
|
44 |
+
print(f"✗ Failed to install triton_kernels: {e}")
|
45 |
+
print("ERROR: MX format will NOT work properly without triton_kernels!")
|
46 |
+
return False
|
47 |
+
|
48 |
+
# Install triton_kernels before other imports
|
49 |
+
_TRITON_INSTALL_SUCCESS = ensure_triton_kernels()
|
50 |
+
|
51 |
+
# ===== MAIN IMPORTS =====
|
52 |
from __future__ import annotations
|
53 |
+
import os, gc, json, torch, warnings, traceback
|
54 |
from dataclasses import dataclass
|
55 |
from typing import List, Dict, Optional, Any, Union
|
56 |
from datetime import datetime
|
|
|
59 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
60 |
import numpy as np
|
61 |
|
62 |
+
# Suppress warnings
|
63 |
warnings.filterwarnings("ignore", message=".*microscaling.*")
|
64 |
warnings.filterwarnings("ignore", message=".*mx.*")
|
65 |
|
|
|
77 |
ReasoningEffort
|
78 |
)
|
79 |
HARMONY_AVAILABLE = True
|
80 |
+
print("✓ OpenAI Harmony loaded successfully")
|
81 |
except ImportError:
|
82 |
+
print("✗ openai_harmony not installed. Install with: pip install openai-harmony")
|
83 |
HARMONY_AVAILABLE = False
|
84 |
|
85 |
+
# Import PEFT for LoRA support
|
86 |
+
try:
|
87 |
+
from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
|
88 |
+
_HAS_PEFT = True
|
89 |
+
print("✓ PEFT loaded successfully")
|
90 |
+
except Exception:
|
91 |
+
_HAS_PEFT = False
|
92 |
+
print("✗ PEFT not available. Install with: pip install peft")
|
93 |
+
|
94 |
+
# Check for triton_kernels (required for MX format)
|
95 |
+
try:
|
96 |
+
import triton_kernels
|
97 |
+
_HAS_TRITON_KERNELS = True
|
98 |
+
print("✓ triton_kernels loaded - MX format enabled")
|
99 |
+
except ImportError:
|
100 |
+
_HAS_TRITON_KERNELS = False
|
101 |
+
print("✗ triton_kernels not available - MX format disabled!")
|
102 |
+
|
103 |
+
# ===== CONFIGURATION =====
|
104 |
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
|
105 |
+
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")
|
106 |
+
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")
|
107 |
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
|
108 |
+
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
|
109 |
+
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
|
110 |
+
ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "1")) == "1"
|
111 |
+
MERGE_ADAPTER = os.getenv("MERGE_ADAPTER", "0") == "1"
|
112 |
|
113 |
+
# Detect if using GPT-OSS model
|
114 |
IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
|
115 |
+
USE_MX_FORMAT = IS_GPT_OSS and _HAS_TRITON_KERNELS
|
116 |
|
117 |
+
# Harmony channels for chain-of-thought
|
118 |
REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
|
119 |
|
120 |
+
# HF Authentication
|
121 |
+
HF_TOKEN = (
|
122 |
os.getenv("HF_TOKEN")
|
123 |
or os.getenv("HUGGING_FACE_HUB_TOKEN")
|
124 |
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
125 |
or os.getenv("HF_ACCESS_TOKEN")
|
126 |
)
|
127 |
|
128 |
+
def _hf_login():
|
129 |
+
"""Login to HuggingFace Hub."""
|
130 |
if HF_TOKEN:
|
131 |
try:
|
132 |
from huggingface_hub import login, whoami
|
133 |
login(token=HF_TOKEN, add_to_git_credential=True)
|
134 |
try:
|
135 |
+
user = whoami(token=HF_TOKEN)
|
136 |
+
print(f"✓ Logged in as: {user.get('name', user.get('id', 'unknown'))}")
|
137 |
+
except:
|
138 |
+
print("✓ HF login successful")
|
139 |
except Exception as e:
|
140 |
+
print(f"✗ HF login failed: {e}")
|
141 |
else:
|
142 |
+
print("✗ No HF_TOKEN found in environment")
|
143 |
|
144 |
+
# Login before loading models
|
145 |
_hf_login()
|
146 |
|
147 |
+
# Disable tokenizer parallelism warning
|
148 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
149 |
|
150 |
+
# ===== LOAD TOKENIZER =====
|
|
|
151 |
try:
|
152 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
|
153 |
+
print(f"✓ Tokenizer loaded from {MODEL_ID}")
|
154 |
except Exception as e:
|
155 |
+
print(f"✗ Failed to load tokenizer: {e}")
|
156 |
raise
|
157 |
|
158 |
+
# ===== HARMONY SETUP =====
|
159 |
+
if HARMONY_AVAILABLE:
|
160 |
+
harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
161 |
+
HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions()
|
162 |
+
else:
|
163 |
+
harmony_encoding = None
|
164 |
+
HARMONY_STOP_IDS = []
|
|
|
|
|
165 |
|
166 |
+
# ===== MODEL LOADING WITH MX FORMAT SUPPORT =====
|
|
|
167 |
|
168 |
+
def detect_mx_format(model) -> bool:
|
169 |
+
"""Check if model is using native MX format."""
|
170 |
+
if not hasattr(model, 'model') or not hasattr(model.model, 'layers'):
|
171 |
+
return False
|
|
|
|
|
|
|
172 |
|
173 |
+
try:
|
174 |
+
first_layer = model.model.layers[0]
|
175 |
+
if hasattr(first_layer, 'block_sparse_moe'):
|
176 |
+
expert = first_layer.block_sparse_moe.experts[0]
|
177 |
+
if hasattr(expert, 'w1'):
|
178 |
+
# Check for MX format scale tensors
|
179 |
+
return hasattr(expert.w1, 'scales')
|
180 |
+
except:
|
181 |
+
pass
|
182 |
+
return False
|
183 |
+
|
184 |
+
def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
|
185 |
+
"""Load the base model with proper MX format handling."""
|
186 |
+
print(f"\n{'='*50}")
|
187 |
+
print(f"Loading model: {MODEL_ID}")
|
188 |
+
print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
|
189 |
+
print(f"{'='*50}\n")
|
190 |
|
191 |
+
# Load config to check model type
|
192 |
+
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
|
193 |
|
194 |
+
# Build loading kwargs
|
195 |
+
load_kwargs = {
|
196 |
+
"trust_remote_code": True,
|
197 |
+
"device_map": device_map,
|
198 |
+
"low_cpu_mem_usage": True,
|
199 |
+
"token": HF_TOKEN,
|
200 |
+
"attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
|
201 |
+
}
|
202 |
|
203 |
+
if IS_GPT_OSS:
|
204 |
+
if _HAS_TRITON_KERNELS:
|
205 |
+
print("✓ Loading with native MX format support")
|
206 |
+
load_kwargs["torch_dtype"] = "auto" # Let model use native MX
|
|
|
207 |
else:
|
208 |
+
print("✗ No triton_kernels - falling back to bf16 (dequantized)")
|
209 |
+
print(" This will likely cause LoRA compatibility issues!")
|
210 |
+
load_kwargs["torch_dtype"] = torch.bfloat16
|
211 |
+
else:
|
212 |
+
# Non-GPT-OSS models
|
213 |
+
load_kwargs["torch_dtype"] = torch.bfloat16
|
|
|
214 |
|
215 |
+
# Load the model
|
216 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
|
217 |
|
218 |
+
# Verify format
|
219 |
+
print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
|
220 |
+
if IS_GPT_OSS:
|
221 |
+
is_mx = detect_mx_format(model)
|
222 |
+
if is_mx:
|
223 |
+
print("✓ Confirmed: Using native MX format")
|
224 |
+
else:
|
225 |
+
print("✗ Model dequantized to bf16 - LoRA may fail")
|
226 |
|
227 |
+
# Set model config
|
228 |
+
if getattr(model.config, "pad_token_id", None) is None:
|
229 |
+
model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
|
230 |
+
model.config.use_cache = True
|
231 |
|
|
|
232 |
return model
|
233 |
|
234 |
+
def load_lora_adapter(model, adapter_id: str, subfolder: Optional[str] = None):
|
235 |
+
"""Load and attach LoRA adapter with MX format handling."""
|
236 |
+
if not _HAS_PEFT:
|
237 |
+
raise RuntimeError("PEFT is required for LoRA adapters")
|
|
|
|
238 |
|
239 |
+
print(f"\n{'='*50}")
|
240 |
+
print(f"Loading LoRA: {adapter_id}")
|
241 |
+
if subfolder:
|
242 |
+
print(f"Subfolder: {subfolder}")
|
243 |
+
print(f"{'='*50}\n")
|
244 |
|
245 |
+
# Check if model is using MX format
|
246 |
+
is_mx = detect_mx_format(model) if IS_GPT_OSS else False
|
247 |
|
248 |
+
# Prepare kwargs for PEFT
|
249 |
+
peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
|
250 |
+
if subfolder:
|
251 |
+
peft_kwargs["subfolder"] = subfolder
|
|
|
|
|
252 |
|
253 |
+
try:
|
254 |
+
# Load adapter configuration
|
255 |
+
peft_config = PeftConfig.from_pretrained(adapter_id, **peft_kwargs)
|
256 |
+
print(f"LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")
|
257 |
|
258 |
+
# Load the adapter
|
259 |
+
model = PeftModel.from_pretrained(model, adapter_id, **peft_kwargs)
|
|
|
260 |
|
261 |
+
if not is_mx and IS_GPT_OSS:
|
262 |
+
print("⚠ WARNING: Model is bf16 but LoRA was likely trained on MX format")
|
263 |
+
print(" Reducing LoRA influence to 10% to prevent corruption")
|
264 |
+
|
265 |
+
# Scale down LoRA weights
|
266 |
+
for name, param in model.named_parameters():
|
267 |
+
if 'lora_' in name:
|
268 |
+
param.data *= 0.1
|
269 |
|
270 |
+
print("✓ LoRA adapter loaded successfully")
|
271 |
+
|
272 |
+
# Optionally merge adapter
|
273 |
+
if MERGE_ADAPTER and hasattr(model, 'merge_and_unload'):
|
274 |
+
print("Merging adapter into base model...")
|
275 |
+
model = model.merge_and_unload()
|
276 |
+
print("✓ Adapter merged")
|
277 |
+
|
278 |
+
return model
|
279 |
+
|
280 |
+
except Exception as e:
|
281 |
+
print(f"✗ Failed to load LoRA: {e}")
|
282 |
+
print("Continuing with base model only")
|
283 |
+
return model
|
284 |
+
|
285 |
+
# ===== HARMONY FORMATTING =====
|
286 |
+
|
287 |
+
def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high"):
|
288 |
+
"""Create Harmony-formatted prompt."""
|
289 |
+
if not HARMONY_AVAILABLE or not harmony_encoding:
|
290 |
+
# Fallback to chat template
|
291 |
+
if messages and messages[0].get("role") != "system":
|
292 |
+
messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
|
293 |
+
return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
294 |
|
295 |
+
# Map reasoning effort
|
296 |
+
effort_map = {
|
297 |
+
"low": ReasoningEffort.LOW,
|
298 |
+
"medium": ReasoningEffort.MEDIUM,
|
299 |
+
"high": ReasoningEffort.HIGH
|
300 |
+
}
|
301 |
+
effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)
|
|
|
302 |
|
303 |
+
# Build Harmony conversation
|
304 |
+
system_content = (
|
305 |
+
SystemContent.new()
|
306 |
+
.with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
|
307 |
+
.with_reasoning_effort(effort)
|
308 |
+
.with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
|
309 |
+
.with_knowledge_cutoff("2024-06")
|
310 |
+
.with_required_channels(REQUIRED_CHANNELS)
|
311 |
+
)
|
312 |
|
313 |
+
# Extract system prompt
|
314 |
+
sys_text = SYSTEM_PROMPT
|
315 |
+
rest = messages or []
|
316 |
+
if rest and rest[0].get("role") == "system":
|
317 |
+
sys_text = rest[0].get("content", SYSTEM_PROMPT)
|
318 |
+
rest = rest[1:]
|
319 |
|
320 |
+
# Build messages
|
321 |
+
harmony_messages = [
|
322 |
+
Message.from_role_and_content(Role.SYSTEM, system_content),
|
323 |
+
Message.from_role_and_content(
|
324 |
+
Role.DEVELOPER,
|
325 |
+
DeveloperContent.new().with_instructions(sys_text)
|
|
|
326 |
)
|
327 |
+
]
|
328 |
+
|
329 |
+
for msg in rest:
|
330 |
+
role = msg.get("role")
|
331 |
+
content = msg.get("content", "")
|
332 |
+
if role == "user":
|
333 |
+
harmony_messages.append(Message.from_role_and_content(Role.USER, content))
|
334 |
+
elif role == "assistant":
|
335 |
+
harmony_messages.append(
|
336 |
+
Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
|
337 |
+
)
|
338 |
+
|
339 |
+
# Render to token IDs
|
340 |
+
convo = Conversation.from_messages(harmony_messages)
|
341 |
+
return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
|
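
Illustrative call using only names defined in this file: with openai_harmony installed the function returns a list of token IDs ready to be wrapped in a tensor, while the fallback path returns a plain chat-template string, which is why generate_on_gpu below checks isinstance(prompt, list).

prompt = create_harmony_prompt(
    [{"role": "user", "content": "Hello!"}],
    reasoning_effort="low",
)
print(type(prompt))  # list of token IDs with Harmony, str on the chat-template fallback
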
|
|
342 |
|
343 |
def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
|
344 |
+
"""Parse Harmony response tokens into channels."""
|
345 |
+
if not HARMONY_AVAILABLE or not harmony_encoding:
|
346 |
text = tokenizer.decode(tokens, skip_special_tokens=False)
|
347 |
+
return {"final": extract_final_channel(text), "raw": text}
|
|
|
348 |
|
|
|
|
|
|
|
|
|
349 |
try:
|
350 |
+
# Parse using Harmony
|
351 |
+
parsed = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
|
352 |
+
|
353 |
+
channels = {}
|
354 |
+
for msg in parsed:
|
355 |
+
channel = getattr(msg, 'channel', 'final')
|
356 |
+
if channel not in channels:
|
357 |
+
channels[channel] = ""
|
358 |
+
|
359 |
+
# Extract text content
|
360 |
+
content = msg.content
|
361 |
+
if isinstance(content, list):
|
362 |
+
text = "".join([getattr(part, "text", str(part)) for part in content])
|
363 |
+
else:
|
364 |
+
text = getattr(content, "text", str(content))
|
365 |
+
|
366 |
+
channels[channel] += text
|
367 |
+
|
368 |
+
# Ensure final channel exists
|
369 |
+
if "final" not in channels:
|
370 |
+
channels["final"] = " ".join(channels.values())
|
371 |
+
|
372 |
+
return channels
|
373 |
+
|
374 |
+
except Exception as e:
|
375 |
+
print(f"Harmony parsing failed: {e}")
|
376 |
+
text = tokenizer.decode(tokens, skip_special_tokens=False)
|
377 |
+
return {"final": extract_final_channel(text), "raw": text}
|
378 |
+
|
379 |
+
def extract_final_channel(text: str) -> str:
|
380 |
+
"""Extract final channel from raw text."""
|
381 |
+
# Look for <|channel|>final<|message|>
|
382 |
+
if "<|channel|>final<|message|>" in text:
|
383 |
+
parts = text.split("<|channel|>final<|message|>")
|
384 |
+
if len(parts) > 1:
|
385 |
+
final = parts[-1]
|
386 |
+
# Truncate at next marker
|
387 |
+
for marker in ["<|channel|>", "<|end|>", "<|return|>"]:
|
388 |
+
if marker in final:
|
389 |
+
final = final.split(marker)[0]
|
390 |
+
return final.strip()
|
391 |
+
|
392 |
+
# Fallback: return cleaned text
|
393 |
+
for marker in ["<|channel|>", "<|message|>", "<|end|>", "<|return|>"]:
|
394 |
+
text = text.replace(marker, " ")
|
395 |
return text.strip()
|
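
For reference, a small worked example of extract_final_channel; the sample string is hypothetical and only mimics the Harmony channel markers the function splits on.

sample = (
    "<|channel|>analysis<|message|>User greets us; answer briefly.<|end|>"
    "<|channel|>final<|message|>Hello! How can I help today?<|return|>"
)
print(extract_final_channel(sample))  # -> "Hello! How can I help today?"
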
396 |
|
397 |
+
# ===== GENERATION =====
|
|
|
398 |
|
|
|
|
|
|
|
399 |
@spaces.GPU(duration=120)
|
400 |
+
def generate_on_gpu(
|
401 |
+
prompt,
|
402 |
+
temperature: float,
|
403 |
+
top_p: float,
|
404 |
+
top_k: int,
|
405 |
+
max_new_tokens: int,
|
406 |
+
do_sample: bool,
|
407 |
+
repetition_penalty: float,
|
408 |
+
seed: Optional[int]
|
409 |
+
) -> Dict[str, str]:
|
410 |
+
"""Run generation on GPU."""
|
411 |
try:
|
412 |
+
# Set seed if provided
|
413 |
if seed is not None:
|
414 |
torch.manual_seed(int(seed))
|
|
|
|
|
|
|
415 |
|
416 |
+
# Load model
|
417 |
+
print("\nLoading model for generation...")
|
418 |
+
model = load_base_model("auto")
|
419 |
+
|
420 |
+
# Load LoRA if specified
|
421 |
+
if ADAPTER_ID:
|
422 |
+
model = load_lora_adapter(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
|
423 |
+
|
424 |
+
model.eval()
|
425 |
+
|
426 |
# Prepare inputs
|
427 |
device = next(model.parameters()).device
|
428 |
+
|
429 |
+
if HARMONY_AVAILABLE and isinstance(prompt, list):
|
430 |
+
# Harmony returns token IDs
|
431 |
+
input_ids = torch.tensor([prompt], dtype=torch.long, device=device)
|
|
|
432 |
else:
|
433 |
+
# String prompt
|
434 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
435 |
+
input_ids = inputs["input_ids"].to(device)
|
436 |
+
|
437 |
+
attention_mask = torch.ones_like(input_ids)
|
438 |
+
prompt_len = input_ids.shape[1]
|
439 |
|
440 |
# Generate
|
441 |
+
print("Generating response...")
|
442 |
+
with torch.no_grad():
|
443 |
+
outputs = model.generate(
|
444 |
+
input_ids=input_ids,
|
445 |
+
attention_mask=attention_mask,
|
446 |
+
max_new_tokens=max_new_tokens,
|
447 |
+
temperature=temperature,
|
448 |
+
top_p=top_p,
|
449 |
+
top_k=top_k if top_k > 0 else None,
|
450 |
+
do_sample=do_sample,
|
451 |
+
repetition_penalty=repetition_penalty,
|
452 |
+
pad_token_id=model.config.pad_token_id,
|
453 |
+
eos_token_id=HARMONY_STOP_IDS if HARMONY_STOP_IDS else tokenizer.eos_token_id,
|
454 |
+
no_repeat_ngram_size=3,
|
455 |
+
)
|
|
|
456 |
|
457 |
# Extract generated tokens
|
458 |
+
gen_tokens = outputs[0][prompt_len:].tolist()
|
|
|
459 |
|
460 |
# Truncate at stop tokens
|
461 |
+
for stop_id in HARMONY_STOP_IDS:
|
462 |
+
if stop_id in gen_tokens:
|
463 |
+
gen_tokens = gen_tokens[:gen_tokens.index(stop_id)]
|
464 |
+
break
|
|
|
465 |
|
466 |
# Parse response
|
467 |
+
channels = parse_harmony_response(gen_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
469 |
return channels
|
470 |
+
|
471 |
except Exception as e:
|
472 |
+
error_msg = f"Generation failed: {str(e)}\n{traceback.format_exc()}"
|
473 |
+
print(error_msg)
|
474 |
+
return {"final": f"Error: {str(e)}", "raw": error_msg}
|
475 |
+
|
476 |
finally:
|
477 |
# Cleanup
|
478 |
+
if 'model' in locals():
|
479 |
del model
|
|
|
|
|
480 |
gc.collect()
|
481 |
if torch.cuda.is_available():
|
482 |
torch.cuda.empty_cache()
|
483 |
|
484 |
+
# ===== GRADIO INTERFACE =====
|
485 |
+
|
486 |
+
def chat_response(
|
487 |
+
message: str,
|
488 |
+
history: List[List[str]],
|
489 |
+
system_prompt: str,
|
490 |
+
temperature: float,
|
491 |
+
top_p: float,
|
492 |
+
top_k: int,
|
493 |
+
max_new_tokens: int,
|
494 |
+
do_sample: bool,
|
495 |
+
repetition_penalty: float,
|
496 |
+
seed: Optional[int],
|
497 |
+
reasoning_effort: str,
|
498 |
+
show_thinking: bool
|
499 |
+
) -> str:
|
500 |
+
"""Handle chat interaction."""
|
501 |
try:
|
502 |
+
# Build conversation
|
503 |
+
messages = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]
|
504 |
|
505 |
+
# Add history
|
506 |
+
for turn in history or []:
|
507 |
+
if isinstance(turn, (list, tuple)) and len(turn) >= 2:
|
508 |
+
user_msg, assistant_msg = turn[0], turn[1]
|
509 |
+
if user_msg:
|
510 |
+
messages.append({"role": "user", "content": str(user_msg)})
|
511 |
+
if assistant_msg:
|
512 |
+
messages.append({"role": "assistant", "content": str(assistant_msg)})
|
513 |
|
514 |
+
# Add current message
|
515 |
+
messages.append({"role": "user", "content": message})
|
516 |
|
517 |
# Create prompt
|
518 |
+
prompt = create_harmony_prompt(messages, reasoning_effort)
|
519 |
+
|
|
|
520 |
# Generate
|
521 |
+
channels = generate_on_gpu(
|
522 |
prompt,
|
523 |
+
temperature,
|
524 |
+
top_p,
|
525 |
+
top_k,
|
526 |
+
max_new_tokens,
|
527 |
+
do_sample,
|
528 |
+
repetition_penalty,
|
529 |
+
seed
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
)
|
531 |
|
532 |
# Format response
|
533 |
+
if show_thinking and len(channels) > 1:
|
534 |
response = "## Chain of Thought:\n\n"
|
535 |
for channel, content in channels.items():
|
536 |
if channel != "final" and content:
|
537 |
+
response += f"### {channel.capitalize()}:\n{content}\n\n"
|
538 |
+
response += f"### Final Response:\n{channels.get('final', 'No response generated')}"
|
|
|
539 |
else:
|
540 |
+
response = channels.get("final", "No response generated")
|
|
|
541 |
|
542 |
+
return response
|
|
|
|
|
543 |
|
544 |
+
except Exception as e:
|
545 |
+
return f"Error: {str(e)}"
|
|
|
546 |
|
547 |
+
# ===== BUILD UI =====
|
548 |
+
|
549 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Mirel") as demo:
|
550 |
+
# Header with status
|
551 |
+
status_mx = "✅ MX Format" if _HAS_TRITON_KERNELS else "❌ No MX Support"
|
552 |
+
status_harmony = "✅ Harmony" if HARMONY_AVAILABLE else "❌ No Harmony"
|
553 |
+
|
554 |
+
gr.Markdown(f"""
|
555 |
+
# 🤖 Mirel - Chain-of-Thought Assistant
|
556 |
+
|
557 |
+
**Model:** `{MODEL_ID}` | **Adapter:** `{ADAPTER_ID or 'None'}`
|
558 |
+
**Status:** {status_mx} | {status_harmony} | {"✅ ZeroGPU" if ZEROGPU else "CPU Mode"}
|
559 |
|
560 |
+
{'''
|
561 |
+
⚠️ **WARNING: MX Format Support Missing!**
|
562 |
+
Install with: `pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels`
|
563 |
+
''' if IS_GPT_OSS and not _HAS_TRITON_KERNELS else ''}
|
564 |
+
""")
|
565 |
+
|
566 |
+
# System prompt
|
567 |
+
system_prompt = gr.Textbox(
|
568 |
+
label="System Prompt",
|
569 |
+
value=SYSTEM_PROMPT,
|
570 |
+
lines=2
|
571 |
+
)
|
572 |
+
|
573 |
+
# Settings
|
574 |
+
with gr.Accordion("⚙️ Generation Settings", open=False):
|
575 |
with gr.Row():
|
576 |
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
577 |
+
top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
|
578 |
+
top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
|
579 |
+
|
580 |
with gr.Row():
|
581 |
+
max_new_tokens = gr.Slider(16, 2048, value=MAX_NEW_TOKENS, step=16, label="Max tokens")
|
582 |
+
repetition_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="Repetition penalty")
|
583 |
seed = gr.Number(value=None, label="Seed (optional)", precision=0)
|
584 |
+
|
585 |
with gr.Row():
|
586 |
+
do_sample = gr.Checkbox(value=True, label="Sample")
|
587 |
+
show_thinking = gr.Checkbox(value=False, label="Show thinking channels")
|
588 |
reasoning_effort = gr.Radio(
|
589 |
+
["low", "medium", "high"],
|
590 |
value="high",
|
591 |
+
label="Reasoning effort"
|
|
|
|
|
|
|
|
|
|
|
|
|
592 |
)
|
593 |
|
|
|
594 |
# Chat interface
|
595 |
chat = gr.ChatInterface(
|
596 |
+
fn=chat_response,
|
|
|
597 |
additional_inputs=[
|
598 |
+
system_prompt,
|
599 |
+
temperature,
|
600 |
+
top_p,
|
601 |
+
top_k,
|
602 |
+
max_new_tokens,
|
603 |
+
do_sample,
|
604 |
+
repetition_penalty,
|
605 |
+
seed,
|
606 |
+
reasoning_effort,
|
607 |
+
show_thinking
|
608 |
],
|
609 |
+
title=None,
|
|
|
610 |
examples=[
|
611 |
["Hello! Can you introduce yourself?"],
|
612 |
+
["What's the capital of France?"],
|
613 |
+
["Explain quantum computing simply"],
|
614 |
+
["Write a haiku about coding"],
|
615 |
],
|
616 |
cache_examples=False,
|
617 |
)
|
618 |
+
|
619 |
+
# Footer
|
620 |
+
gr.Markdown("""
|
621 |
+
---
|
622 |
+
💡 **Tips:**
|
623 |
+
- Enable "Show thinking channels" to see the model's reasoning process
|
624 |
+
- Adjust "Reasoning effort" for faster responses (low) or better quality (high)
|
625 |
+
- The model uses MX format on H200 GPUs for optimal performance
|
626 |
+
""")
|
627 |
+
|
628 |
+
# ===== LAUNCH =====
|
|
|
629 |
if __name__ == "__main__":
|
630 |
+
print("\n" + "="*60)
|
631 |
+
print("MIREL READY TO LAUNCH")
|
632 |
+
print(f"Model: {MODEL_ID}")
|
633 |
+
print(f"Adapter: {ADAPTER_ID or 'None'}")
|
634 |
+
print(f"MX Format: {'ENABLED' if _HAS_TRITON_KERNELS else 'DISABLED'}")
|
635 |
+
print(f"Harmony: {'ENABLED' if HARMONY_AVAILABLE else 'DISABLED'}")
|
636 |
+
print("="*60 + "\n")
|
637 |
+
|
638 |
+
demo.queue(max_size=10).launch(
|
639 |
+
server_name="0.0.0.0",
|
640 |
server_port=7860,
|
641 |
share=False
|
642 |
)
|
install.sh
ADDED
@@ -0,0 +1,97 @@
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Complete installation script for Mirel with MX format support on H200
|
3 |
+
|
4 |
+
echo "Installing Mirel dependencies for GPT-OSS with MX format support..."
|
5 |
+
|
6 |
+
# Upgrade pip first
|
7 |
+
pip install --upgrade pip
|
8 |
+
|
9 |
+
# Install main requirements
|
10 |
+
pip install "huggingface_hub>=0.34.0"
|
11 |
+
pip install "transformers>=4.55.0"
|
12 |
+
pip install "accelerate>=0.33.0"
|
13 |
+
pip install "torch>=2.4.0"
|
14 |
+
pip install "gradio>=5.42.0"
|
15 |
+
pip install spaces
|
16 |
+
|
17 |
+
# Install LoRA/PEFT support
|
18 |
+
pip install "peft>=0.11.0"
|
19 |
+
pip install "bitsandbytes>=0.43.1"
|
20 |
+
|
21 |
+
# Install Harmony format
|
22 |
+
pip install openai-harmony
|
23 |
+
|
24 |
+
# Install Triton and MX format support
|
25 |
+
pip install "triton>=3.4.0"
|
26 |
+
|
27 |
+
# CRITICAL: Install triton_kernels from git subdirectory
|
28 |
+
# This is REQUIRED for MX format on H200 GPUs
|
29 |
+
echo "Installing triton_kernels (REQUIRED for MX format)..."
|
30 |
+
pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
|
31 |
+
|
32 |
+
# Optional but recommended
|
33 |
+
pip install "safetensors>=0.4.0"
|
34 |
+
pip install "sentencepiece>=0.2.0"
|
35 |
+
pip install "protobuf>=3.20.0"
|
36 |
+
pip install "numpy<2.0.0"
|
37 |
+
|
38 |
+
# Verify critical imports
|
39 |
+
echo "Verifying installation..."
|
40 |
+
python -c "
|
41 |
+
import sys
|
42 |
+
errors = []
|
43 |
+
|
44 |
+
try:
|
45 |
+
import torch
|
46 |
+
print(f'✓ PyTorch {torch.__version__}')
|
47 |
+
except ImportError as e:
|
48 |
+
errors.append(f'✗ PyTorch: {e}')
|
49 |
+
|
50 |
+
try:
|
51 |
+
import transformers
|
52 |
+
print(f'✓ Transformers {transformers.__version__}')
|
53 |
+
except ImportError as e:
|
54 |
+
errors.append(f'✗ Transformers: {e}')
|
55 |
+
|
56 |
+
try:
|
57 |
+
import peft
|
58 |
+
print(f'✓ PEFT {peft.__version__}')
|
59 |
+
except ImportError as e:
|
60 |
+
errors.append(f'✗ PEFT: {e}')
|
61 |
+
|
62 |
+
try:
|
63 |
+
import triton
|
64 |
+
print(f'✓ Triton {triton.__version__}')
|
65 |
+
except ImportError as e:
|
66 |
+
errors.append(f'✗ Triton: {e}')
|
67 |
+
|
68 |
+
try:
|
69 |
+
import triton_kernels
|
70 |
+
print('✓ Triton Kernels (MX format support)')
|
71 |
+
except ImportError as e:
|
72 |
+
errors.append(f'✗ Triton Kernels (CRITICAL): {e}')
|
73 |
+
print('⚠️ WARNING: MX format will NOT work without triton_kernels!')
|
74 |
+
|
75 |
+
try:
|
76 |
+
import openai_harmony
|
77 |
+
print('✓ OpenAI Harmony')
|
78 |
+
except ImportError as e:
|
79 |
+
errors.append(f'✗ OpenAI Harmony: {e}')
|
80 |
+
|
81 |
+
try:
|
82 |
+
import gradio
|
83 |
+
print(f'✓ Gradio {gradio.__version__}')
|
84 |
+
except ImportError as e:
|
85 |
+
errors.append(f'✗ Gradio: {e}')
|
86 |
+
|
87 |
+
if errors:
|
88 |
+
print('\n✗ Installation issues found:')
|
89 |
+
for error in errors:
|
90 |
+
print(f' {error}')
|
91 |
+
sys.exit(1)
|
92 |
+
else:
|
93 |
+
print('\n✅ All dependencies installed successfully!')
|
94 |
+
print('You can now run the Mirel app with MX format support on H200 GPUs')
|
95 |
+
"
|
96 |
+
|
97 |
+
echo "Installation complete!"
|
requirements.txt
CHANGED
@@ -1,10 +1,25 @@
|
|
|
|
1 |
huggingface_hub>=0.34.0
|
2 |
transformers>=4.55.0
|
3 |
accelerate>=0.33.0
|
|
|
|
|
|
|
|
|
|
|
4 |
peft>=0.11.0
|
5 |
-
torch>=2.4.0 # ZeroGPU-supported (2.3.x is NOT supported)
|
6 |
bitsandbytes>=0.43.1
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
9 |
triton>=3.4.0
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Core dependencies
|
2 |
huggingface_hub>=0.34.0
|
3 |
transformers>=4.55.0
|
4 |
accelerate>=0.33.0
|
5 |
+
torch>=2.4.0
|
6 |
+
gradio>=5.42.0
|
7 |
+
spaces
|
8 |
+
|
9 |
+
# LoRA/PEFT support
|
10 |
peft>=0.11.0
|
|
|
11 |
bitsandbytes>=0.43.1
|
12 |
+
|
13 |
+
# Harmony format for OpenAI GPT-OSS models
|
14 |
+
openai-harmony
|
15 |
+
|
16 |
+
# MX format support (REQUIRED for GPT-OSS-20B on H200)
|
17 |
triton>=3.4.0
|
18 |
+
# Note: triton_kernels must be installed separately from git:
|
19 |
+
# pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
|
20 |
+
|
21 |
+
# Optional but recommended
|
22 |
+
safetensors>=0.4.0
|
23 |
+
sentencepiece>=0.2.0
|
24 |
+
protobuf>=3.20.0
|
25 |
+
numpy<2.0.0 # Some dependencies may not support numpy 2.x yet
|
setup.py
ADDED
@@ -0,0 +1,31 @@
|
|
1 |
+
"""
|
2 |
+
setup.py - Run this at the start of app.py to ensure triton_kernels is installed
|
3 |
+
Add this to the top of your app.py file in HF Spaces
|
4 |
+
"""
|
5 |
+
|
6 |
+
import subprocess
|
7 |
+
import sys
|
8 |
+
|
9 |
+
def ensure_triton_kernels():
|
10 |
+
"""Ensure triton_kernels is installed for MX format support."""
|
11 |
+
try:
|
12 |
+
import triton_kernels
|
13 |
+
print("✓ triton_kernels already installed")
|
14 |
+
return True
|
15 |
+
except ImportError:
|
16 |
+
print("Installing triton_kernels for MX format support...")
|
17 |
+
try:
|
18 |
+
subprocess.check_call([
|
19 |
+
sys.executable, "-m", "pip", "install",
|
20 |
+
"git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
|
21 |
+
])
|
22 |
+
print("✓ triton_kernels installed successfully")
|
23 |
+
return True
|
24 |
+
except subprocess.CalledProcessError as e:
|
25 |
+
print(f"✗ Failed to install triton_kernels: {e}")
|
26 |
+
print("WARNING: MX format will fall back to bf16, LoRA may not work!")
|
27 |
+
return False
|
28 |
+
|
29 |
+
# Run at import time
|
30 |
+
if __name__ != "__main__": # When imported
|
31 |
+
ensure_triton_kernels()
|
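
A minimal sketch of the wiring described in the docstring above, assuming the file keeps the name setup.py and sits next to app.py: importing the module triggers ensure_triton_kernels() through the guard at the bottom, so it runs before transformers loads the GPT-OSS checkpoint.

# At the very top of app.py, before transformers or the model is imported:
import setup  # noqa: F401  (side-effect import installs triton_kernels if missing)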