AbstractPhil committed on
Commit 6228595 · 1 Parent(s): 4ab6146
Files changed (4)
  1. app.py +478 -637
  2. install.sh +97 -0
  3. requirements.txt +19 -4
  4. setup.py +31 -0
app.py CHANGED
@@ -3,9 +3,54 @@ Mirel Harmony Inference – HF Space (Gradio)
3
  ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B
4
  Proper LoRA adapter loading and conversion for MX compatibility
5
  Single file: app.py
6
  """
7
  from __future__ import annotations
8
- import os, gc, json, threading, torch, warnings
9
  from dataclasses import dataclass
10
  from typing import List, Dict, Optional, Any, Union
11
  from datetime import datetime
@@ -14,7 +59,7 @@ import spaces # required for ZeroGPU
14
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
15
  import numpy as np
16
 
17
- # Suppress warnings about MX format
18
  warnings.filterwarnings("ignore", message=".*microscaling.*")
19
  warnings.filterwarnings("ignore", message=".*mx.*")
20
 
@@ -32,770 +77,566 @@ try:
32
  ReasoningEffort
33
  )
34
  HARMONY_AVAILABLE = True
 
35
  except ImportError:
36
- print("[WARNING] openai_harmony not installed. Install with: pip install openai-harmony")
37
  HARMONY_AVAILABLE = False
38
 
39
- # -----------------------
40
- # Config & runtime modes
41
- # -----------------------
42
- # MX format uses special dtypes - we need to handle this properly
43
  MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
44
- ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b") # Default to your adapter
45
- ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516") # Default to the subfolder
46
  ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
47
- SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
48
- MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
49
- ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
 
50
 
51
- # For GPT-OSS models, we need specific handling
52
  IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
53
- USE_MX_FORMAT = os.getenv("USE_MX_FORMAT", "1" if IS_GPT_OSS else "0") == "1"
54
 
55
- # Harmony channels for CoT
56
  REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
57
 
58
- # HF Auth
59
- HF_TOKEN: Optional[str] = (
60
  os.getenv("HF_TOKEN")
61
  or os.getenv("HUGGING_FACE_HUB_TOKEN")
62
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
63
  or os.getenv("HF_ACCESS_TOKEN")
64
  )
65
 
66
- def _hf_login() -> None:
67
- """Login to HF Hub using common env secret names."""
68
  if HF_TOKEN:
69
  try:
70
  from huggingface_hub import login, whoami
71
  login(token=HF_TOKEN, add_to_git_credential=True)
72
  try:
73
- who = whoami(token=HF_TOKEN)
74
- print(f"[HF Auth] Logged in as: {who.get('name') or who.get('fullname') or who.get('id', 'unknown')}")
75
- except Exception:
76
- print("[HF Auth] Login successful but couldn't get user info")
77
  except Exception as e:
78
- print(f"[HF Auth] Login failed: {e}")
79
  else:
80
- print("[HF Auth] No token found in environment variables")
81
 
82
- # Login before loading any models
83
  _hf_login()
84
 
 
85
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
86
 
87
- # Load Harmony encoding if available
88
- if HARMONY_AVAILABLE:
89
- harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
90
- else:
91
- harmony_encoding = None
92
-
93
- # Stop tokens per Harmony spec: <|return|> (200002), <|call|> (200012)
94
- HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions() if HARMONY_AVAILABLE else []
95
-
96
- # Tokenizer is lightweight; load once
97
  try:
98
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
99
- print(f"[Model] Successfully loaded tokenizer from {MODEL_ID}")
100
  except Exception as e:
101
- print(f"[Model] Failed to load tokenizer: {e}")
102
  raise
103
 
104
- # -----------------------
105
- # PEFT and MX Format Support
106
- # -----------------------
107
- try:
108
- from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
109
- _HAS_PEFT = True
110
- except Exception:
111
- _HAS_PEFT = False
112
- print("[Warning] PEFT not available. Install with: pip install peft")
113
 
114
- # Try to import microscaling support if available
115
- try:
116
- import msamp
117
- _HAS_MSAMP = True
118
- print("[Info] Microsoft AMP (msamp) available for MX format support")
119
- except ImportError:
120
- _HAS_MSAMP = False
121
- print("[Info] msamp not available - using fallback MX handling")
122
-
123
- # -----------------------
124
- # MX Format Conversion
125
- # -----------------------
126
- def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
127
- """
128
- Convert fp32 LoRA weights to be compatible with MX format base model.
129
- MX models expect specific dtype handling.
130
- """
131
- converted = {}
132
-
133
- for key, tensor in lora_state_dict.items():
134
- if tensor is None:
135
- converted[key] = tensor
136
- continue
137
-
138
- # LoRA weights (lora_A, lora_B) need special handling
139
- if 'lora_' in key:
140
- # For MX compatibility, we keep weights in fp32 but ensure proper scaling
141
- # MX format internally handles quantization, we just need clean fp32 inputs
142
- if tensor.dtype != torch.float32:
143
- tensor = tensor.to(torch.float32)
144
-
145
- # Ensure weights are in reasonable range for MX quantization
146
- # MX format works best with weights in [-1, 1] range
147
- if 'lora_A' in key:
148
- # Input projection - initialize with small values
149
- std = 1.0 / torch.sqrt(torch.tensor(tensor.shape[1], dtype=torch.float32))
150
- if tensor.std() > std * 10: # If weights are too large
151
- print(f"[MX Convert] Scaling down {key} from std={tensor.std():.4f} to {std:.4f}")
152
- tensor = tensor * (std / tensor.std())
153
- elif 'lora_B' in key:
154
- # Output projection - should be near zero initially
155
- if tensor.abs().max() > 0.1:
156
- print(f"[MX Convert] Scaling down {key} max={tensor.abs().max():.4f}")
157
- tensor = tensor * 0.01
158
-
159
- converted[key] = tensor
160
- else:
161
- # Non-LoRA weights (like embeddings) stay as-is
162
- converted[key] = tensor
163
-
164
- return converted
165
 
166
- def prepare_model_for_mx_lora(model, adapter_path: str, subfolder: Optional[str] = None):
167
- """
168
- Prepare and attach LoRA adapter to MX format model.
169
- Handles the special requirements of GPT-OSS MX models.
170
- """
171
- if not _HAS_PEFT:
172
- raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
173
 
174
- # Build the full path including subfolder
175
- full_adapter_path = adapter_path
176
- if subfolder:
177
- print(f"[LoRA] Loading adapter from {adapter_path} (subfolder: {subfolder})")
178
- else:
179
- print(f"[LoRA] Loading adapter from {adapter_path}")
180
-
181
- # Load the LoRA config with subfolder support
182
- peft_kwargs = {"token": HF_TOKEN}
183
- if subfolder:
184
- peft_kwargs["subfolder"] = subfolder
 
185
 
186
- peft_config = PeftConfig.from_pretrained(adapter_path, **peft_kwargs)
 
187
 
188
- # Load the LoRA weights - need to check in the right location
189
- from safetensors.torch import load_file
190
- import os.path as osp
191
- from huggingface_hub import hf_hub_download
 
 
 
 
192
 
193
- try:
194
- # Try to download from HF Hub with subfolder
195
- if subfolder:
196
- # Download the adapter weights file
197
- try:
198
- adapter_weights_path = hf_hub_download(
199
- repo_id=adapter_path,
200
- filename="adapter_model.safetensors",
201
- subfolder=subfolder,
202
- token=HF_TOKEN
203
- )
204
- adapter_weights = load_file(adapter_weights_path)
205
- print(f"[LoRA] Loaded safetensors weights from {subfolder}")
206
- except Exception:
207
- # Try .bin format
208
- adapter_weights_path = hf_hub_download(
209
- repo_id=adapter_path,
210
- filename="adapter_model.bin",
211
- subfolder=subfolder,
212
- token=HF_TOKEN
213
- )
214
- adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
215
- print(f"[LoRA] Loaded bin weights from {subfolder}")
216
  else:
217
- # No subfolder - try local path first, then HF Hub
218
- local_safetensors = osp.join(adapter_path, "adapter_model.safetensors")
219
- local_bin = osp.join(adapter_path, "adapter_model.bin")
220
-
221
- if osp.exists(local_safetensors):
222
- adapter_weights = load_file(local_safetensors)
223
- print("[LoRA] Loaded local safetensors weights")
224
- elif osp.exists(local_bin):
225
- adapter_weights = torch.load(local_bin, map_location="cpu")
226
- print("[LoRA] Loaded local bin weights")
227
- else:
228
- # Try downloading from HF Hub
229
- try:
230
- adapter_weights_path = hf_hub_download(
231
- repo_id=adapter_path,
232
- filename="adapter_model.safetensors",
233
- token=HF_TOKEN
234
- )
235
- adapter_weights = load_file(adapter_weights_path)
236
- print("[LoRA] Downloaded safetensors weights from Hub")
237
- except Exception:
238
- adapter_weights_path = hf_hub_download(
239
- repo_id=adapter_path,
240
- filename="adapter_model.bin",
241
- token=HF_TOKEN
242
- )
243
- adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
244
- print("[LoRA] Downloaded bin weights from Hub")
245
-
246
- except Exception as e:
247
- raise FileNotFoundError(f"Could not load adapter weights: {e}")
248
-
249
- # Convert weights for MX compatibility
250
- print("[LoRA] Converting fp32 weights for MX format compatibility...")
251
- adapter_weights = convert_fp32_lora_to_mx_compatible(adapter_weights)
252
 
253
- # Create PEFT model with special handling for MX
254
- print("[LoRA] Attaching LoRA to base model...")
255
 
256
- # For MX models, we need to be careful about dtype
257
- # The base model uses MX format internally, but the interface should be fp32
258
- model = PeftModel.from_pretrained(
259
- model,
260
- adapter_path,
261
- is_trainable=False,
262
- **peft_kwargs # This includes token and subfolder
263
- )
264
 
265
- # Manually update the adapter weights with our converted versions
266
- model.load_state_dict(adapter_weights, strict=False)
 
 
267
 
268
- print("[LoRA] Successfully attached LoRA adapter with MX compatibility")
269
  return model
270
 
271
- # -----------------------
272
- # Model loading with MX support
273
- # -----------------------
274
- def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
275
- """Build kwargs for model loading with MX format support."""
276
- kw: Dict[str, Any] = dict(
277
- device_map=device_map,
278
- trust_remote_code=True,
279
- low_cpu_mem_usage=True,
280
- token=HF_TOKEN,
281
- )
282
-
283
- if IS_GPT_OSS and USE_MX_FORMAT:
284
- # GPT-OSS models use MX format
285
- # Don't specify torch_dtype - let the model use its native MX format
286
- print("[Model] Using MX format for GPT-OSS model")
287
- kw.update({
288
- "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
289
- # MX models handle their own dtype internally
290
- # Don't force a dtype here
291
- })
292
- else:
293
- # Non-MX models
294
- kw.update({
295
- "torch_dtype": torch.float16, # Use fp16 for non-MX models
296
- "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
297
- })
298
 
299
- return kw
300
-
301
- def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
302
- """Load model with proper MX format handling."""
303
- print(f"[Model] Loading base model from {MODEL_ID}...")
304
 
305
- # Load config first to check for MX format
306
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
307
 
308
- # Check if this is an MX model
309
- is_mx_model = (
310
- IS_GPT_OSS or
311
- hasattr(config, 'quantization_config') and 'mx' in str(config.quantization_config).lower() or
312
- hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower()
313
- )
314
 
315
- if is_mx_model:
316
- print("[Model] Detected MX format model - using special loading")
 
 
317
 
318
- # For MX models, we need special handling
319
- # The model internally uses MX quantization
320
- model = AutoModelForCausalLM.from_pretrained(
321
- MODEL_ID,
322
- config=config,
323
- trust_remote_code=True,
324
- device_map=device_map,
325
- low_cpu_mem_usage=True,
326
- token=HF_TOKEN,
327
- # Let the model handle its own dtype
328
- attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
329
- )
330
 
331
- # Verify the model loaded correctly
332
- print(f"[Model] Model dtype: {next(model.parameters()).dtype}")
333
- print(f"[Model] Model device: {next(model.parameters()).device}")
 
 
 
 
 
334
 
335
- else:
336
- # Standard model loading
337
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))
338
 
339
- # Load and attach LoRA adapter if specified
340
- if ADAPTER_ID:
341
- try:
342
- if is_mx_model:
343
- # Use special MX-compatible LoRA loading with subfolder support
344
- model = prepare_model_for_mx_lora(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
345
- else:
346
- # Standard PEFT loading for non-MX models
347
- if not _HAS_PEFT:
348
- raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
349
- print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
350
- peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
351
- if ADAPTER_SUBFOLDER:
352
- peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
353
- print(f"[Model] Using subfolder: {ADAPTER_SUBFOLDER}")
354
- model = PeftModel.from_pretrained(
355
- model,
356
- ADAPTER_ID,
357
- **peft_kwargs
358
- )
359
-
360
- print("[Model] Successfully loaded with LoRA adapter")
361
-
362
- # Optionally merge adapter for better performance
363
- merge_adapter = os.getenv("MERGE_ADAPTER", "0") == "1"
364
- if merge_adapter and hasattr(model, 'merge_and_unload'):
365
- print("[Model] Merging adapter into base model...")
366
- model = model.merge_and_unload()
367
- print("[Model] Adapter merged successfully")
368
-
369
- except Exception as e:
370
- print(f"[Error] Failed to load adapter: {e}")
371
- print("[Warning] Continuing with base model only")
372
 
373
- model.eval()
 
374
 
375
- # Ensure proper config
376
- if getattr(model.config, "pad_token_id", None) is None:
377
- model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
378
- model.config.use_cache = True
 
 
379
 
380
- print(f"[Model] Model loaded successfully - Type: {'MX Format' if is_mx_model else 'Standard'}")
381
- return model
382
-
383
- # -----------------------
384
- # Harmony formatting
385
- # -----------------------
386
- def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
387
- """Build a Harmony-formatted prompt."""
388
- if HARMONY_AVAILABLE and harmony_encoding is not None:
389
- effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
390
- effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)
391
-
392
- system_content = (
393
- SystemContent.new()
394
- .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
395
- .with_reasoning_effort(effort)
396
- .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
397
- .with_knowledge_cutoff("2024-06")
398
- .with_required_channels(REQUIRED_CHANNELS)
399
  )
400
-
401
- sys_text = SYSTEM_DEF
402
- rest: List[Dict[str, str]] = messages or []
403
- if rest and rest[0].get("role") == "system":
404
- sys_text = rest[0].get("content") or SYSTEM_DEF
405
- rest = rest[1:]
406
-
407
- harmony_messages = [Message.from_role_and_content(Role.SYSTEM, system_content)]
408
- dev = DeveloperContent.new().with_instructions(sys_text)
409
- harmony_messages.append(Message.from_role_and_content(Role.DEVELOPER, dev))
410
-
411
- for m in rest:
412
- role = m.get("role"); content = m.get("content", "")
413
- if role == "user":
414
- harmony_messages.append(Message.from_role_and_content(Role.USER, content))
415
- elif role == "assistant":
416
- harmony_messages.append(
417
- Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
418
- )
419
-
420
- convo = Conversation.from_messages(harmony_messages)
421
- return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
422
-
423
- # Fallback: tokenizer chat template
424
- if not messages or messages[0].get("role") != "system":
425
- messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
426
- return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
427
 
428
  def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
429
- """Parse response tokens using Harmony format to extract channels."""
430
- if not HARMONY_AVAILABLE:
431
  text = tokenizer.decode(tokens, skip_special_tokens=False)
432
- return {"final": extract_final_channel_fallback(text), "raw": text}
433
-
434
- parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
435
-
436
- channels = {}
437
- for msg in parsed_messages:
438
- channel = msg.channel if hasattr(msg, 'channel') else "final"
439
- if channel not in channels:
440
- channels[channel] = ""
441
- channels[channel] += "".join([getattr(part, "text", str(part)) for part in (msg.content if isinstance(msg.content, list) else [msg.content])])
442
-
443
- if "final" not in channels:
444
- channels["final"] = " ".join(channels.values())
445
 
446
- return channels
447
-
448
- def extract_final_channel_fallback(text: str) -> str:
449
- """Extract the <final> channel from decoded Harmony text."""
450
  try:
451
- chunks: Dict[str, str] = {}
452
- pieces = text.split("<|channel|>")
453
- for seg in pieces[1:]:
454
- name_end = seg.find("<|message|>")
455
- if name_end <= 0:
456
- continue
457
- ch = seg[:name_end].strip()
458
- body_start = name_end + len("<|message|>")
459
- next_pos = len(seg)
460
- for delim in ("<|channel|>", "<|end|>", "<|return|>"):
461
- p = seg.find(delim, body_start)
462
- if p != -1:
463
- next_pos = min(next_pos, p)
464
- body = seg[body_start:next_pos]
465
- chunks[ch] = chunks.get(ch, "") + body
466
- final_txt = (chunks.get("final", "").strip())
467
- if final_txt:
468
- return final_txt
469
- if "<|channel|>final<|message|>" in text:
470
- tail = text.split("<|channel|>final<|message|>")[-1]
471
- for delim in ("<|return|>", "<|end|>", "<|channel|>"):
472
- idx = tail.find(delim)
473
- if idx != -1:
474
- tail = tail[:idx]
475
- break
476
- return tail.strip()
477
- except Exception:
478
- pass
479
  return text.strip()
480
 
481
- # -----------------------
482
- # Rose guidance
483
- # -----------------------
484
- def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
485
- """Create vocab bias from {token: weight}."""
486
- vocab_size = len(tokenizer)
487
- bias = torch.zeros(vocab_size, dtype=torch.float32)
488
- for tok, w in mapping.items():
489
- if tok is None:
490
- continue
491
- tid = tokenizer.convert_tokens_to_ids(tok)
492
- if isinstance(tid, list):
493
- for t in tid:
494
- if isinstance(t, int) and t >= 0:
495
- bias[t] += float(w) / max(1, len(tid))
496
- elif isinstance(tid, int) and tid >= 0:
497
- bias[tid] += float(w)
498
- return bias
499
-
500
- class RoseGuidedLogits(torch.nn.Module):
501
- def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
502
- super().__init__()
503
- self.bias_vec = bias_vec
504
- self.alpha = float(alpha)
505
-
506
- def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
507
- return scores + self.alpha * self.bias_vec.to(scores.device)
508
 
509
- # -----------------------
510
- # Generation
511
- # -----------------------
512
  @spaces.GPU(duration=120)
513
- def zerogpu_generate(full_prompt,
514
- gen_kwargs: Dict[str, Any],
515
- rose_map: Optional[Dict[str, float]],
516
- rose_alpha: float,
517
- rose_score: Optional[float],
518
- seed: Optional[int]) -> Dict[str, str]:
519
- """Run inference on GPU with MX format support."""
 
 
 
 
520
  try:
 
521
  if seed is not None:
522
  torch.manual_seed(int(seed))
523
-
524
- # Load model with MX support
525
- model = _load_model_on("auto")
526
 
527
- # Setup logits processor for Rose guidance
528
- logits_processor = None
529
- if rose_map:
530
- bias = build_bias_from_tokens(tokenizer, rose_map).to(next(model.parameters()).device)
531
- eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
532
- logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
533
-
 
 
 
534
  # Prepare inputs
535
  device = next(model.parameters()).device
536
- if HARMONY_AVAILABLE and isinstance(full_prompt, list):
537
- input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
538
- attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
539
- inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
540
- prompt_len = input_ids.shape[1]
541
  else:
542
- enc = tokenizer(full_prompt, return_tensors="pt")
543
- inputs = enc.to(device)
544
- prompt_len = int(inputs["input_ids"].shape[1])
545
- if "attention_mask" not in inputs:
546
- inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
 
547
 
548
  # Generate
549
- eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
550
-
551
- out_ids = model.generate(
552
- **inputs,
553
- do_sample=bool(gen_kwargs.get("do_sample", True)),
554
- temperature=float(gen_kwargs.get("temperature", 0.7)),
555
- top_p=float(gen_kwargs.get("top_p", 0.9)),
556
- top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
557
- max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
558
- pad_token_id=model.config.pad_token_id,
559
- eos_token_id=eos_ids,
560
- logits_processor=logits_processor,
561
- repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
562
- no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
563
- min_new_tokens=1,
564
- )
565
 
566
  # Extract generated tokens
567
- out_list = out_ids[0].tolist()
568
- gen_ids = out_list[prompt_len:]
569
 
570
  # Truncate at stop tokens
571
- if HARMONY_AVAILABLE:
572
- for sid in HARMONY_STOP_IDS:
573
- if sid in gen_ids:
574
- gen_ids = gen_ids[:gen_ids.index(sid)]
575
- break
576
 
577
  # Parse response
578
- if HARMONY_AVAILABLE:
579
- try:
580
- channels = parse_harmony_response(gen_ids)
581
- except Exception:
582
- decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
583
- channels = {
584
- "final": extract_final_channel_fallback(decoded),
585
- "raw": decoded
586
- }
587
- else:
588
- decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
589
- channels = {
590
- "final": extract_final_channel_fallback(decoded),
591
- "raw": decoded
592
- }
593
 
594
  return channels
595
-
596
  except Exception as e:
597
- import traceback
598
- error_trace = traceback.format_exc()
599
- print(f"[Error] Generation failed:\n{error_trace}")
600
- return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": error_trace}
601
  finally:
602
  # Cleanup
603
- try:
604
  del model
605
- except:
606
- pass
607
  gc.collect()
608
  if torch.cuda.is_available():
609
  torch.cuda.empty_cache()
610
 
611
- # -----------------------
612
- # Gradio handlers
613
- # -----------------------
614
- def generate_response(message: str, history: List[List[str]], system_prompt: str,
615
- temperature: float, top_p: float, top_k: int, max_new_tokens: int,
616
- do_sample: bool, seed: Optional[int],
617
- rose_enable: bool, rose_alpha: float, rose_score: Optional[float],
618
- rose_tokens: str, rose_json: str,
619
- show_thinking: bool = False,
620
- reasoning_effort: str = "high") -> str:
621
- """Generate response with CoT handling."""
622
  try:
623
- # Build messages
624
- messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
625
 
626
- if history:
627
- for turn in history:
628
- if isinstance(turn, (list, tuple)) and len(turn) >= 2:
629
- user_msg, assistant_msg = turn[0], turn[1]
630
- if user_msg:
631
- messages.append({"role": "user", "content": str(user_msg)})
632
- if assistant_msg:
633
- messages.append({"role": "assistant", "content": str(assistant_msg)})
634
 
635
- messages.append({"role": "user", "content": str(message)})
 
636
 
637
  # Create prompt
638
- if HARMONY_AVAILABLE:
639
- prompt = create_harmony_prompt(messages, reasoning_effort)
640
- else:
641
- prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
642
-
643
- # Build Rose map
644
- rose_map: Optional[Dict[str, float]] = None
645
- if rose_enable:
646
- rose_map = {}
647
- tok_str = (rose_tokens or "").strip()
648
- if tok_str:
649
- for p in [p.strip() for p in tok_str.split(",") if p.strip()]:
650
- if ":" in p:
651
- k, v = p.split(":", 1)
652
- try:
653
- rose_map[k.strip()] = float(v)
654
- except:
655
- pass
656
- if rose_json:
657
- try:
658
- j = json.loads(rose_json)
659
- if isinstance(j, dict):
660
- for k, v in j.items():
661
- try:
662
- rose_map[str(k)] = float(v)
663
- except:
664
- pass
665
- except:
666
- pass
667
- if not rose_map:
668
- rose_map = None
669
-
670
  # Generate
671
- channels = zerogpu_generate(
672
  prompt,
673
- {
674
- "do_sample": bool(do_sample),
675
- "temperature": float(temperature),
676
- "top_p": float(top_p),
677
- "top_k": int(top_k) if top_k > 0 else None,
678
- "max_new_tokens": int(max_new_tokens),
679
- "repetition_penalty": 1.1,
680
- "no_repeat_ngram_size": 6,
681
- },
682
- rose_map,
683
- float(rose_alpha),
684
- float(rose_score) if rose_score is not None else None,
685
- int(seed) if seed is not None else None,
686
  )
687
 
688
  # Format response
689
- if show_thinking:
690
  response = "## Chain of Thought:\n\n"
691
  for channel, content in channels.items():
692
  if channel != "final" and content:
693
- response += f"### {channel.capitalize()} Channel:\n{content}\n\n"
694
- response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
695
- return response
696
  else:
697
- return channels.get("final", "No final response generated")
698
-
699
- except Exception as e:
700
- import traceback
701
- return f"[Error] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
702
-
703
- # -----------------------
704
- # UI
705
- # -----------------------
706
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
707
- gr.Markdown(
708
- f"""
709
- # Mirel – Harmony Chain-of-Thought Inference
710
 
711
- **Model**: {MODEL_ID} {'(MX Format)' if USE_MX_FORMAT else ''}
712
- **Adapter**: {ADAPTER_ID or 'None'}
713
- **Status**: {'βœ… Harmony Available' if HARMONY_AVAILABLE else '⚠️ Harmony Not Installed'}
714
 
715
- The model uses internal thinking channels before providing final responses.
716
- """
717
- )
718
 
719
- with gr.Row():
720
- system_prompt = gr.Textbox(
721
- label="System Prompt",
722
- value=SYSTEM_DEF,
723
- lines=2
724
- )
 
725
 
726
- with gr.Accordion("Generation Settings", open=False):
727
  with gr.Row():
728
  temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
729
- top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
730
- top_k = gr.Slider(0, 200, value=0, step=1, label="Top-k (0=disabled)")
 
731
  with gr.Row():
732
- max_new = gr.Slider(16, 4096, value=MAX_DEF, step=16, label="Max new tokens")
733
- do_sample = gr.Checkbox(value=True, label="Do sample")
734
  seed = gr.Number(value=None, label="Seed (optional)", precision=0)
 
735
  with gr.Row():
 
 
736
  reasoning_effort = gr.Radio(
737
- choices=["low", "medium", "high"],
738
  value="high",
739
- label="Reasoning Effort",
740
- info="How much thinking the model should do"
741
- )
742
- show_thinking = gr.Checkbox(
743
- value=False,
744
- label="Show thinking channels",
745
- info="Display all internal reasoning channels"
746
  )
747
 
748
- with gr.Accordion("Rose Guidance (Optional)", open=False):
749
- gr.Markdown("Fine-tune generation with token biases")
750
- with gr.Row():
751
- rose_enable = gr.Checkbox(value=False, label="Enable Rose bias")
752
- rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="Alpha (strength)")
753
- rose_score = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Score multiplier")
754
- rose_tokens = gr.Textbox(
755
- label="Token:weight pairs",
756
- placeholder="example:1.5, test:-0.5",
757
- value=""
758
- )
759
- rose_json = gr.Textbox(
760
- label="JSON weights",
761
- placeholder='{"token": 1.0, "another": -0.5}',
762
- value=""
763
- )
764
-
765
  # Chat interface
766
  chat = gr.ChatInterface(
767
- fn=generate_response,
768
- type="messages",
769
  additional_inputs=[
770
- system_prompt, temperature, top_p, top_k, max_new,
771
- do_sample, seed, rose_enable, rose_alpha, rose_score,
772
- rose_tokens, rose_json, show_thinking, reasoning_effort
 
773
  ],
774
- title="Chat with Mirel",
775
- description="Chain-of-thought model with MX format support",
776
  examples=[
777
  ["Hello! Can you introduce yourself?"],
778
- ["What is the capital of France?"],
779
- ["Explain quantum computing in simple terms"],
780
- ["Solve: If a train travels 120 miles in 2 hours, what is its average speed?"],
781
  ],
782
  cache_examples=False,
783
  )
784
-
785
- gr.Markdown(
786
- """
787
- ---
788
- ### Configuration:
789
- - **MX Format**: Automatically detected for GPT-OSS models
790
- - **LoRA Support**: fp32 LoRA adapters are converted for MX compatibility
791
- - **Merge Adapter**: Set `MERGE_ADAPTER=1` to merge LoRA into base model
792
- - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
793
- """
794
- )
795
-
796
  if __name__ == "__main__":
797
- demo.queue(max_size=8 if ZEROGPU else 32).launch(
798
- server_name="0.0.0.0",
799
  server_port=7860,
800
  share=False
801
  )
 
3
  ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B
4
  Proper LoRA adapter loading and conversion for MX compatibility
5
  Single file: app.py
6
+
7
+ Requirements:
8
+ huggingface_hub>=0.34.0
9
+ transformers>=4.55.0
10
+ accelerate>=0.33.0
11
+ peft>=0.11.0
12
+ torch>=2.4.0
13
+ bitsandbytes>=0.43.1
14
+ openai-harmony
15
+ gradio>=5.42.0
16
+ triton>=3.4.0
17
+ git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
18
  """
19
+
20
+ # ===== SETUP: Ensure triton_kernels is installed for MX format =====
21
+ import subprocess
22
+ import sys
23
+
24
+ def ensure_triton_kernels():
25
+ """Ensure triton_kernels is installed for MX format support on H200."""
26
+ try:
27
+ import triton_kernels
28
+ print("✓ triton_kernels already installed - MX format supported")
29
+ return True
30
+ except ImportError:
31
+ print("Installing triton_kernels for MX format support...")
32
+ try:
33
+ subprocess.check_call([
34
+ sys.executable, "-m", "pip", "install",
35
+ "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
36
+ ])
37
+ print("✓ triton_kernels installed successfully")
38
+ # Force reimport
39
+ import importlib
40
+ import site
41
+ importlib.reload(site)
42
+ return True
43
+ except subprocess.CalledProcessError as e:
44
+ print(f"✗ Failed to install triton_kernels: {e}")
45
+ print("ERROR: MX format will NOT work properly without triton_kernels!")
46
+ return False
47
+
48
+ # Install triton_kernels before other imports
49
+ _TRITON_INSTALL_SUCCESS = ensure_triton_kernels()
50
+
51
+ # ===== MAIN IMPORTS =====
52
  from __future__ import annotations
53
+ import os, gc, json, torch, warnings, traceback
54
  from dataclasses import dataclass
55
  from typing import List, Dict, Optional, Any, Union
56
  from datetime import datetime
 
59
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
60
  import numpy as np
61
 
62
+ # Suppress warnings
63
  warnings.filterwarnings("ignore", message=".*microscaling.*")
64
  warnings.filterwarnings("ignore", message=".*mx.*")
65
 
 
77
  ReasoningEffort
78
  )
79
  HARMONY_AVAILABLE = True
80
+ print("✓ OpenAI Harmony loaded successfully")
81
  except ImportError:
82
+ print("⚠ openai_harmony not installed. Install with: pip install openai-harmony")
83
  HARMONY_AVAILABLE = False
84
 
85
+ # Import PEFT for LoRA support
86
+ try:
87
+ from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
88
+ _HAS_PEFT = True
89
+ print("✓ PEFT loaded successfully")
90
+ except Exception:
91
+ _HAS_PEFT = False
92
+ print("⚠ PEFT not available. Install with: pip install peft")
93
+
94
+ # Check for triton_kernels (required for MX format)
95
+ try:
96
+ import triton_kernels
97
+ _HAS_TRITON_KERNELS = True
98
+ print("✓ triton_kernels loaded - MX format enabled")
99
+ except ImportError:
100
+ _HAS_TRITON_KERNELS = False
101
+ print("✗ triton_kernels not available - MX format disabled!")
102
+
103
+ # ===== CONFIGURATION =====
104
  MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
105
+ ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")
106
+ ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")
107
  ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
108
+ SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
109
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
110
+ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "1")) == "1"
111
+ MERGE_ADAPTER = os.getenv("MERGE_ADAPTER", "0") == "1"
112
 
113
+ # Detect if using GPT-OSS model
114
  IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
115
+ USE_MX_FORMAT = IS_GPT_OSS and _HAS_TRITON_KERNELS
116
 
117
+ # Harmony channels for chain-of-thought
118
  REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
119
 
120
+ # HF Authentication
121
+ HF_TOKEN = (
122
  os.getenv("HF_TOKEN")
123
  or os.getenv("HUGGING_FACE_HUB_TOKEN")
124
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
125
  or os.getenv("HF_ACCESS_TOKEN")
126
  )
127
 
128
+ def _hf_login():
129
+ """Login to HuggingFace Hub."""
130
  if HF_TOKEN:
131
  try:
132
  from huggingface_hub import login, whoami
133
  login(token=HF_TOKEN, add_to_git_credential=True)
134
  try:
135
+ user = whoami(token=HF_TOKEN)
136
+ print(f"✓ Logged in as: {user.get('name', user.get('id', 'unknown'))}")
137
+ except Exception:
138
+ print("✓ HF login successful")
139
  except Exception as e:
140
+ print(f"⚠ HF login failed: {e}")
141
  else:
142
+ print("⚠ No HF_TOKEN found in environment")
143
 
144
+ # Login before loading models
145
  _hf_login()
146
 
147
+ # Disable tokenizer parallelism warning
148
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
149
 
150
+ # ===== LOAD TOKENIZER =====
151
  try:
152
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
153
+ print(f"✓ Tokenizer loaded from {MODEL_ID}")
154
  except Exception as e:
155
+ print(f"✗ Failed to load tokenizer: {e}")
156
  raise
157
 
158
+ # ===== HARMONY SETUP =====
159
+ if HARMONY_AVAILABLE:
160
+ harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
161
+ HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions()
162
+ else:
163
+ harmony_encoding = None
164
+ HARMONY_STOP_IDS = []
 
 
165
 
166
+ # ===== MODEL LOADING WITH MX FORMAT SUPPORT =====
167
 
168
+ def detect_mx_format(model) -> bool:
169
+ """Check if model is using native MX format."""
170
+ if not hasattr(model, 'model') or not hasattr(model.model, 'layers'):
171
+ return False
 
 
 
172
 
173
+ try:
174
+ first_layer = model.model.layers[0]
175
+ if hasattr(first_layer, 'block_sparse_moe'):
176
+ expert = first_layer.block_sparse_moe.experts[0]
177
+ if hasattr(expert, 'w1'):
178
+ # Check for MX format scale tensors
179
+ return hasattr(expert.w1, 'scales')
180
+ except Exception:
181
+ pass
182
+ return False
183
+
184
+ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
185
+ """Load the base model with proper MX format handling."""
186
+ print(f"\n{'='*50}")
187
+ print(f"Loading model: {MODEL_ID}")
188
+ print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
189
+ print(f"{'='*50}\n")
190
 
191
+ # Load config to check model type
192
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
193
 
194
+ # Build loading kwargs
195
+ load_kwargs = {
196
+ "trust_remote_code": True,
197
+ "device_map": device_map,
198
+ "low_cpu_mem_usage": True,
199
+ "token": HF_TOKEN,
200
+ "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
201
+ }
202
 
203
+ if IS_GPT_OSS:
204
+ if _HAS_TRITON_KERNELS:
205
+ print("→ Loading with native MX format support")
206
+ load_kwargs["torch_dtype"] = "auto" # Let model use native MX
207
  else:
208
+ print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
209
+ print(" This will likely cause LoRA compatibility issues!")
210
+ load_kwargs["torch_dtype"] = torch.bfloat16
211
+ else:
212
+ # Non-GPT-OSS models
213
+ load_kwargs["torch_dtype"] = torch.bfloat16
214
 
215
+ # Load the model
216
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
217
 
218
+ # Verify format
219
+ print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
220
+ if IS_GPT_OSS:
221
+ is_mx = detect_mx_format(model)
222
+ if is_mx:
223
+ print("✓ Confirmed: Using native MX format")
224
+ else:
225
+ print("⚠ Model dequantized to bf16 - LoRA may fail")
226
 
227
+ # Set model config
228
+ if getattr(model.config, "pad_token_id", None) is None:
229
+ model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
230
+ model.config.use_cache = True
231
 
 
232
  return model
233
 
234
+ def load_lora_adapter(model, adapter_id: str, subfolder: Optional[str] = None):
235
+ """Load and attach LoRA adapter with MX format handling."""
236
+ if not _HAS_PEFT:
237
+ raise RuntimeError("PEFT is required for LoRA adapters")
238
 
239
+ print(f"\n{'='*50}")
240
+ print(f"Loading LoRA: {adapter_id}")
241
+ if subfolder:
242
+ print(f"Subfolder: {subfolder}")
243
+ print(f"{'='*50}\n")
244
 
245
+ # Check if model is using MX format
246
+ is_mx = detect_mx_format(model) if IS_GPT_OSS else False
247
 
248
+ # Prepare kwargs for PEFT
249
+ peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
250
+ if subfolder:
251
+ peft_kwargs["subfolder"] = subfolder
 
 
252
 
253
+ try:
254
+ # Load adapter configuration
255
+ peft_config = PeftConfig.from_pretrained(adapter_id, **peft_kwargs)
256
+ print(f"LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")
257
 
258
+ # Load the adapter
259
+ model = PeftModel.from_pretrained(model, adapter_id, **peft_kwargs)
260
 
261
+ if not is_mx and IS_GPT_OSS:
262
+ print("⚠ WARNING: Model is bf16 but LoRA was likely trained on MX format")
263
+ print(" Reducing LoRA influence to 10% to prevent corruption")
264
+
265
+ # Scale down LoRA weights
266
+ for name, param in model.named_parameters():
267
+ if 'lora_' in name:
268
+ param.data *= 0.1
269
 
270
+ print("✓ LoRA adapter loaded successfully")
271
+
272
+ # Optionally merge adapter
273
+ if MERGE_ADAPTER and hasattr(model, 'merge_and_unload'):
274
+ print("Merging adapter into base model...")
275
+ model = model.merge_and_unload()
276
+ print("✓ Adapter merged")
277
+
278
+ return model
279
+
280
+ except Exception as e:
281
+ print(f"✗ Failed to load LoRA: {e}")
282
+ print("Continuing with base model only")
283
+ return model
284
+
285
+ # ===== HARMONY FORMATTING =====
286
+
287
+ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high"):
288
+ """Create Harmony-formatted prompt."""
289
+ if not HARMONY_AVAILABLE or not harmony_encoding:
290
+ # Fallback to chat template
291
+ if messages and messages[0].get("role") != "system":
292
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
293
+ return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
294
 
295
+ # Map reasoning effort
296
+ effort_map = {
297
+ "low": ReasoningEffort.LOW,
298
+ "medium": ReasoningEffort.MEDIUM,
299
+ "high": ReasoningEffort.HIGH
300
+ }
301
+ effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)
302
 
303
+ # Build Harmony conversation
304
+ system_content = (
305
+ SystemContent.new()
306
+ .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
307
+ .with_reasoning_effort(effort)
308
+ .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
309
+ .with_knowledge_cutoff("2024-06")
310
+ .with_required_channels(REQUIRED_CHANNELS)
311
+ )
312
 
313
+ # Extract system prompt
314
+ sys_text = SYSTEM_PROMPT
315
+ rest = messages or []
316
+ if rest and rest[0].get("role") == "system":
317
+ sys_text = rest[0].get("content", SYSTEM_PROMPT)
318
+ rest = rest[1:]
319
 
320
+ # Build messages
321
+ harmony_messages = [
322
+ Message.from_role_and_content(Role.SYSTEM, system_content),
323
+ Message.from_role_and_content(
324
+ Role.DEVELOPER,
325
+ DeveloperContent.new().with_instructions(sys_text)
326
  )
327
+ ]
328
+
329
+ for msg in rest:
330
+ role = msg.get("role")
331
+ content = msg.get("content", "")
332
+ if role == "user":
333
+ harmony_messages.append(Message.from_role_and_content(Role.USER, content))
334
+ elif role == "assistant":
335
+ harmony_messages.append(
336
+ Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
337
+ )
338
+
339
+ # Render to token IDs
340
+ convo = Conversation.from_messages(harmony_messages)
341
+ return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
342
 
343
  def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
344
+ """Parse Harmony response tokens into channels."""
345
+ if not HARMONY_AVAILABLE or not harmony_encoding:
346
  text = tokenizer.decode(tokens, skip_special_tokens=False)
347
+ return {"final": extract_final_channel(text), "raw": text}
348
 
 
 
 
 
349
  try:
350
+ # Parse using Harmony
351
+ parsed = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
352
+
353
+ channels = {}
354
+ for msg in parsed:
355
+ channel = getattr(msg, 'channel', 'final')
356
+ if channel not in channels:
357
+ channels[channel] = ""
358
+
359
+ # Extract text content
360
+ content = msg.content
361
+ if isinstance(content, list):
362
+ text = "".join([getattr(part, "text", str(part)) for part in content])
363
+ else:
364
+ text = getattr(content, "text", str(content))
365
+
366
+ channels[channel] += text
367
+
368
+ # Ensure final channel exists
369
+ if "final" not in channels:
370
+ channels["final"] = " ".join(channels.values())
371
+
372
+ return channels
373
+
374
+ except Exception as e:
375
+ print(f"Harmony parsing failed: {e}")
376
+ text = tokenizer.decode(tokens, skip_special_tokens=False)
377
+ return {"final": extract_final_channel(text), "raw": text}
378
+
379
+ def extract_final_channel(text: str) -> str:
380
+ """Extract final channel from raw text."""
381
+ # Look for <|channel|>final<|message|>
382
+ if "<|channel|>final<|message|>" in text:
383
+ parts = text.split("<|channel|>final<|message|>")
384
+ if len(parts) > 1:
385
+ final = parts[-1]
386
+ # Truncate at next marker
387
+ for marker in ["<|channel|>", "<|end|>", "<|return|>"]:
388
+ if marker in final:
389
+ final = final.split(marker)[0]
390
+ return final.strip()
391
+
392
+ # Fallback: return cleaned text
393
+ for marker in ["<|channel|>", "<|message|>", "<|end|>", "<|return|>"]:
394
+ text = text.replace(marker, " ")
395
  return text.strip()
396
 
397
+ # ===== GENERATION =====
398
 
 
 
 
399
  @spaces.GPU(duration=120)
400
+ def generate_on_gpu(
401
+ prompt,
402
+ temperature: float,
403
+ top_p: float,
404
+ top_k: int,
405
+ max_new_tokens: int,
406
+ do_sample: bool,
407
+ repetition_penalty: float,
408
+ seed: Optional[int]
409
+ ) -> Dict[str, str]:
410
+ """Run generation on GPU."""
411
  try:
412
+ # Set seed if provided
413
  if seed is not None:
414
  torch.manual_seed(int(seed))
 
 
 
415
 
416
+ # Load model
417
+ print("\nLoading model for generation...")
418
+ model = load_base_model("auto")
419
+
420
+ # Load LoRA if specified
421
+ if ADAPTER_ID:
422
+ model = load_lora_adapter(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
423
+
424
+ model.eval()
425
+
426
  # Prepare inputs
427
  device = next(model.parameters()).device
428
+
429
+ if HARMONY_AVAILABLE and isinstance(prompt, list):
430
+ # Harmony returns token IDs
431
+ input_ids = torch.tensor([prompt], dtype=torch.long, device=device)
 
432
  else:
433
+ # String prompt
434
+ inputs = tokenizer(prompt, return_tensors="pt")
435
+ input_ids = inputs["input_ids"].to(device)
436
+
437
+ attention_mask = torch.ones_like(input_ids)
438
+ prompt_len = input_ids.shape[1]
439
 
440
  # Generate
441
+ print("Generating response...")
442
+ with torch.no_grad():
443
+ outputs = model.generate(
444
+ input_ids=input_ids,
445
+ attention_mask=attention_mask,
446
+ max_new_tokens=max_new_tokens,
447
+ temperature=temperature,
448
+ top_p=top_p,
449
+ top_k=top_k if top_k > 0 else None,
450
+ do_sample=do_sample,
451
+ repetition_penalty=repetition_penalty,
452
+ pad_token_id=model.config.pad_token_id,
453
+ eos_token_id=HARMONY_STOP_IDS if HARMONY_STOP_IDS else tokenizer.eos_token_id,
454
+ no_repeat_ngram_size=3,
455
+ )
 
456
 
457
  # Extract generated tokens
458
+ gen_tokens = outputs[0][prompt_len:].tolist()
 
459
 
460
  # Truncate at stop tokens
461
+ for stop_id in HARMONY_STOP_IDS:
462
+ if stop_id in gen_tokens:
463
+ gen_tokens = gen_tokens[:gen_tokens.index(stop_id)]
464
+ break
 
465
 
466
  # Parse response
467
+ channels = parse_harmony_response(gen_tokens)
468
 
469
  return channels
470
+
471
  except Exception as e:
472
+ error_msg = f"Generation failed: {str(e)}\n{traceback.format_exc()}"
473
+ print(error_msg)
474
+ return {"final": f"Error: {str(e)}", "raw": error_msg}
475
+
476
  finally:
477
  # Cleanup
478
+ if 'model' in locals():
479
  del model
 
 
480
  gc.collect()
481
  if torch.cuda.is_available():
482
  torch.cuda.empty_cache()
483
 
484
+ # ===== GRADIO INTERFACE =====
485
+
486
+ def chat_response(
487
+ message: str,
488
+ history: List[List[str]],
489
+ system_prompt: str,
490
+ temperature: float,
491
+ top_p: float,
492
+ top_k: int,
493
+ max_new_tokens: int,
494
+ do_sample: bool,
495
+ repetition_penalty: float,
496
+ seed: Optional[int],
497
+ reasoning_effort: str,
498
+ show_thinking: bool
499
+ ) -> str:
500
+ """Handle chat interaction."""
501
  try:
502
+ # Build conversation
503
+ messages = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]
504
 
505
+ # Add history
506
+ for turn in history or []:
507
+ if isinstance(turn, (list, tuple)) and len(turn) >= 2:
508
+ user_msg, assistant_msg = turn[0], turn[1]
509
+ if user_msg:
510
+ messages.append({"role": "user", "content": str(user_msg)})
511
+ if assistant_msg:
512
+ messages.append({"role": "assistant", "content": str(assistant_msg)})
513
 
514
+ # Add current message
515
+ messages.append({"role": "user", "content": message})
516
 
517
  # Create prompt
518
+ prompt = create_harmony_prompt(messages, reasoning_effort)
519
+
520
  # Generate
521
+ channels = generate_on_gpu(
522
  prompt,
523
+ temperature,
524
+ top_p,
525
+ top_k,
526
+ max_new_tokens,
527
+ do_sample,
528
+ repetition_penalty,
529
+ seed
530
  )
531
 
532
  # Format response
533
+ if show_thinking and len(channels) > 1:
534
  response = "## Chain of Thought:\n\n"
535
  for channel, content in channels.items():
536
  if channel != "final" and content:
537
+ response += f"### {channel.capitalize()}:\n{content}\n\n"
538
+ response += f"### Final Response:\n{channels.get('final', 'No response generated')}"
 
539
  else:
540
+ response = channels.get("final", "No response generated")
 
541
 
542
+ return response
 
 
543
 
544
+ except Exception as e:
545
+ return f"Error: {str(e)}"
 
546
 
547
+ # ===== BUILD UI =====
548
+
549
+ with gr.Blocks(theme=gr.themes.Soft(), title="Mirel") as demo:
550
+ # Header with status
551
+ status_mx = "✅ MX Format" if _HAS_TRITON_KERNELS else "❌ No MX Support"
552
+ status_harmony = "✅ Harmony" if HARMONY_AVAILABLE else "❌ No Harmony"
553
+
554
+ gr.Markdown(f"""
555
+ # 🤖 Mirel – Chain-of-Thought Assistant
556
+
557
+ **Model:** `{MODEL_ID}` | **Adapter:** `{ADAPTER_ID or 'None'}`
558
+ **Status:** {status_mx} | {status_harmony} | {"✅ ZeroGPU" if ZEROGPU else "CPU Mode"}
559
 
560
+ {'''
561
+ ⚠️ **WARNING: MX Format Support Missing!**
562
+ Install with: `pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels`
563
+ ''' if IS_GPT_OSS and not _HAS_TRITON_KERNELS else ''}
564
+ """)
565
+
566
+ # System prompt
567
+ system_prompt = gr.Textbox(
568
+ label="System Prompt",
569
+ value=SYSTEM_PROMPT,
570
+ lines=2
571
+ )
572
+
573
+ # Settings
574
+ with gr.Accordion("⚙️ Generation Settings", open=False):
575
  with gr.Row():
576
  temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
577
+ top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
578
+ top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
579
+
580
  with gr.Row():
581
+ max_new_tokens = gr.Slider(16, 2048, value=MAX_NEW_TOKENS, step=16, label="Max tokens")
582
+ repetition_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="Repetition penalty")
583
  seed = gr.Number(value=None, label="Seed (optional)", precision=0)
584
+
585
  with gr.Row():
586
+ do_sample = gr.Checkbox(value=True, label="Sample")
587
+ show_thinking = gr.Checkbox(value=False, label="Show thinking channels")
588
  reasoning_effort = gr.Radio(
589
+ ["low", "medium", "high"],
590
  value="high",
591
+ label="Reasoning effort"
592
  )
593
594
  # Chat interface
595
  chat = gr.ChatInterface(
596
+ fn=chat_response,
 
597
  additional_inputs=[
598
+ system_prompt,
599
+ temperature,
600
+ top_p,
601
+ top_k,
602
+ max_new_tokens,
603
+ do_sample,
604
+ repetition_penalty,
605
+ seed,
606
+ reasoning_effort,
607
+ show_thinking
608
  ],
609
+ title=None,
 
610
  examples=[
611
  ["Hello! Can you introduce yourself?"],
612
+ ["What's the capital of France?"],
613
+ ["Explain quantum computing simply"],
614
+ ["Write a haiku about coding"],
615
  ],
616
  cache_examples=False,
617
  )
618
+
619
+ # Footer
620
+ gr.Markdown("""
621
+ ---
622
+ 💡 **Tips:**
623
+ - Enable "Show thinking channels" to see the model's reasoning process
624
+ - Adjust "Reasoning effort" for faster responses (low) or better quality (high)
625
+ - The model uses MX format on H200 GPUs for optimal performance
626
+ """)
627
+
628
+ # ===== LAUNCH =====
 
629
  if __name__ == "__main__":
630
+ print("\n" + "="*60)
631
+ print("MIREL READY TO LAUNCH")
632
+ print(f"Model: {MODEL_ID}")
633
+ print(f"Adapter: {ADAPTER_ID or 'None'}")
634
+ print(f"MX Format: {'ENABLED' if _HAS_TRITON_KERNELS else 'DISABLED'}")
635
+ print(f"Harmony: {'ENABLED' if HARMONY_AVAILABLE else 'DISABLED'}")
636
+ print("="*60 + "\n")
637
+
638
+ demo.queue(max_size=10).launch(
639
+ server_name="0.0.0.0",
640
  server_port=7860,
641
  share=False
642
  )
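For reference, a minimal standalone sketch of the channel splitting that the new extract_final_channel() performs. The sample string is hypothetical text built from the Harmony control markers the function splits on, not real model output.

# Illustrative sketch only: mirrors the marker logic of extract_final_channel() above.
# The sample below is hypothetical text, not actual GPT-OSS output.
def extract_final(text: str) -> str:
    marker = "<|channel|>final<|message|>"
    if marker in text:
        final = text.split(marker)[-1]
        # cut at the next control marker, as the app does
        for stop in ("<|channel|>", "<|end|>", "<|return|>"):
            if stop in final:
                final = final.split(stop)[0]
        return final.strip()
    return text.strip()

sample = ("<|channel|>analysis<|message|>Reason about the question step by step...<|end|>"
          "<|channel|>final<|message|>Paris is the capital of France.<|return|>")
print(extract_final(sample))  # -> Paris is the capital of France.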
install.sh ADDED
@@ -0,0 +1,97 @@
1
+ #!/bin/bash
2
+ # Complete installation script for Mirel with MX format support on H200
3
+
4
+ echo "Installing Mirel dependencies for GPT-OSS with MX format support..."
5
+
6
+ # Upgrade pip first
7
+ pip install --upgrade pip
8
+
9
+ # Install main requirements
10
+ pip install "huggingface_hub>=0.34.0"
11
+ pip install "transformers>=4.55.0"
12
+ pip install "accelerate>=0.33.0"
13
+ pip install "torch>=2.4.0"
14
+ pip install "gradio>=5.42.0"
15
+ pip install spaces
16
+
17
+ # Install LoRA/PEFT support
18
+ pip install "peft>=0.11.0"
19
+ pip install "bitsandbytes>=0.43.1"
20
+
21
+ # Install Harmony format
22
+ pip install openai-harmony
23
+
24
+ # Install Triton and MX format support
25
+ pip install "triton>=3.4.0"
26
+
27
+ # CRITICAL: Install triton_kernels from git subdirectory
28
+ # This is REQUIRED for MX format on H200 GPUs
29
+ echo "Installing triton_kernels (REQUIRED for MX format)..."
30
+ pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
31
+
32
+ # Optional but recommended
33
+ pip install "safetensors>=0.4.0"
34
+ pip install "sentencepiece>=0.2.0"
35
+ pip install "protobuf>=3.20.0"
36
+ pip install "numpy<2.0.0"
37
+
38
+ # Verify critical imports
39
+ echo "Verifying installation..."
40
+ python -c "
41
+ import sys
42
+ errors = []
43
+
44
+ try:
45
+ import torch
46
+ print(f'✓ PyTorch {torch.__version__}')
47
+ except ImportError as e:
48
+ errors.append(f'✗ PyTorch: {e}')
49
+
50
+ try:
51
+ import transformers
52
+ print(f'✓ Transformers {transformers.__version__}')
53
+ except ImportError as e:
54
+ errors.append(f'✗ Transformers: {e}')
55
+
56
+ try:
57
+ import peft
58
+ print(f'✓ PEFT {peft.__version__}')
59
+ except ImportError as e:
60
+ errors.append(f'✗ PEFT: {e}')
61
+
62
+ try:
63
+ import triton
64
+ print(f'✓ Triton {triton.__version__}')
65
+ except ImportError as e:
66
+ errors.append(f'✗ Triton: {e}')
67
+
68
+ try:
69
+ import triton_kernels
70
+ print('✓ Triton Kernels (MX format support)')
71
+ except ImportError as e:
72
+ errors.append(f'✗ Triton Kernels (CRITICAL): {e}')
73
+ print('⚠️ WARNING: MX format will NOT work without triton_kernels!')
74
+
75
+ try:
76
+ import openai_harmony
77
+ print('✓ OpenAI Harmony')
78
+ except ImportError as e:
79
+ errors.append(f'✗ OpenAI Harmony: {e}')
80
+
81
+ try:
82
+ import gradio
83
+ print(f'✓ Gradio {gradio.__version__}')
84
+ except ImportError as e:
85
+ errors.append(f'✗ Gradio: {e}')
86
+
87
+ if errors:
88
+ print('\n❌ Installation issues found:')
89
+ for error in errors:
90
+ print(f' {error}')
91
+ sys.exit(1)
92
+ else:
93
+ print('\n✅ All dependencies installed successfully!')
94
+ print('You can now run the Mirel app with MX format support on H200 GPUs')
95
+ "
96
+
97
+ echo "Installation complete!"
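The import checks above can be complemented by a hardware check; a hedged sketch follows. The compute-capability threshold (9.0, i.e. Hopper-class GPUs such as the H200 this script targets) is an illustrative assumption, not something the script itself asserts.

# Hedged sketch: check GPU + kernel prerequisites for the MX path at runtime.
# ASSUMPTION: compute capability 9.0+ (Hopper, e.g. H100/H200) is treated as the
# MX-capable generation; the threshold is illustrative, not taken from install.sh.
import torch

def mx_prereqs_ok() -> bool:
    try:
        import triton_kernels  # the package install.sh pulls from the Triton repo
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(0)
    return major >= 9

print("MX prerequisites:", "OK" if mx_prereqs_ok() else "missing")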
requirements.txt CHANGED
@@ -1,10 +1,25 @@
 
1
  huggingface_hub>=0.34.0
2
  transformers>=4.55.0
3
  accelerate>=0.33.0
4
  peft>=0.11.0
5
- torch>=2.4.0 # ZeroGPU-supported (2.3.x is NOT supported)
6
  bitsandbytes>=0.43.1
7
- openai_harmony
8
- gradio>=5.42.0
 
 
 
9
  triton>=3.4.0
10
- msamp
1
+ # Core dependencies
2
  huggingface_hub>=0.34.0
3
  transformers>=4.55.0
4
  accelerate>=0.33.0
5
+ torch>=2.4.0
6
+ gradio>=5.42.0
7
+ spaces
8
+
9
+ # LoRA/PEFT support
10
  peft>=0.11.0
 
11
  bitsandbytes>=0.43.1
12
+
13
+ # Harmony format for OpenAI GPT-OSS models
14
+ openai-harmony
15
+
16
+ # MX format support (REQUIRED for GPT-OSS-20B on H200)
17
  triton>=3.4.0
18
+ # Note: triton_kernels must be installed separately from git:
19
+ # pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
20
+
21
+ # Optional but recommended
22
+ safetensors>=0.4.0
23
+ sentencepiece>=0.2.0
24
+ protobuf>=3.20.0
25
+ numpy<2.0.0 # Some dependencies may not support numpy 2.x yet
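A small optional sketch for checking the pins above at runtime: it prints the installed version next to each minimum using importlib.metadata, with distribution names as they appear in this file.

# Minimal sketch: print installed versions next to the minimums pinned in requirements.txt.
from importlib.metadata import version, PackageNotFoundError

pins = {
    "transformers": "4.55.0",
    "accelerate": "0.33.0",
    "peft": "0.11.0",
    "gradio": "5.42.0",
    "triton": "3.4.0",
}
for name, minimum in pins.items():
    try:
        print(f"{name}: installed {version(name)} (requirement >= {minimum})")
    except PackageNotFoundError:
        print(f"{name}: NOT INSTALLED (requirement >= {minimum})")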
setup.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ setup.py - Run this at the start of app.py to ensure triton_kernels is installed
3
+ Add this to the top of your app.py file in HF Spaces
4
+ """
5
+
6
+ import subprocess
7
+ import sys
8
+
9
+ def ensure_triton_kernels():
10
+ """Ensure triton_kernels is installed for MX format support."""
11
+ try:
12
+ import triton_kernels
13
+ print("✓ triton_kernels already installed")
14
+ return True
15
+ except ImportError:
16
+ print("Installing triton_kernels for MX format support...")
17
+ try:
18
+ subprocess.check_call([
19
+ sys.executable, "-m", "pip", "install",
20
+ "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
21
+ ])
22
+ print("✓ triton_kernels installed successfully")
23
+ return True
24
+ except subprocess.CalledProcessError as e:
25
+ print(f"✗ Failed to install triton_kernels: {e}")
26
+ print("WARNING: MX format will fall back to bf16, LoRA may not work!")
27
+ return False
28
+
29
+ # Run at import time
30
+ if __name__ != "__main__": # When imported
31
+ ensure_triton_kernels()
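A hedged variant of the install-then-import pattern used here: after a same-process pip install, importlib.invalidate_caches() is the documented way to make sure the import system sees the freshly installed package. The helper name below is illustrative, not part of this commit.

# Variant sketch (assumption, not part of the commit): install a package and import
# it in the same interpreter, refreshing the import system's finder caches first.
import importlib
import subprocess
import sys

def install_and_import(pip_spec: str, module_name: str):
    subprocess.check_call([sys.executable, "-m", "pip", "install", pip_spec])
    importlib.invalidate_caches()  # pick up files installed after interpreter start
    return importlib.import_module(module_name)

# Example with the same git spec setup.py uses:
# tk = install_and_import(
#     "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels",
#     "triton_kernels",
# )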