AbstractPhil committed on
Commit 5c9afc5 · 1 Parent(s): a2f6c58

claude helps again

Files changed (1)
  1. app.py +249 -89
app.py CHANGED
@@ -1,17 +1,22 @@
  """
  Mirel Harmony Inference – HF Space (Gradio)
- ZeroGPU-ready, Harmony formatting, optional Rose-guided decoding
- Chain-of-thought model with proper channel extraction using openai_harmony
  Single file: app.py
  """
  from __future__ import annotations
- import os, gc, json, threading, torch
  from dataclasses import dataclass
- from typing import List, Dict, Optional, Any
  from datetime import datetime
  import gradio as gr
  import spaces # required for ZeroGPU
- from transformers import AutoTokenizer, AutoModelForCausalLM

  # Import Harmony components
  try:
@@ -34,22 +39,23 @@ except ImportError:
  # -----------------------
  # Config & runtime modes
  # -----------------------
- DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
-
  MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
  ADAPTER_ID = os.getenv("ADAPTER_ID") or None
  ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None
  ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
- DTYPE = DTYPE_MAP.get(os.getenv("DTYPE", "bf16").lower(), torch.bfloat16)
  SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
  MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
  ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
- LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"

  # Harmony channels for CoT
  REQUIRED_CHANNELS = ["analysis", "commentary", "final"]

- # HF Auth - properly handle multiple token env var names
  HF_TOKEN: Optional[str] = (
  os.getenv("HF_TOKEN")
  or os.getenv("HUGGING_FACE_HUB_TOKEN")
@@ -96,65 +102,231 @@ except Exception as e:
  raise

  # -----------------------
- # Model loading
  # -----------------------
  try:
- from peft import PeftModel
  _HAS_PEFT = True
  except Exception:
  _HAS_PEFT = False


  def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
  kw: Dict[str, Any] = dict(
- torch_dtype=DTYPE,
  device_map=device_map,
- attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
  trust_remote_code=True,
  low_cpu_mem_usage=True,
  token=HF_TOKEN,
  )
- if LOAD_4BIT and device_map != "cpu":
- try:
- import bitsandbytes as _bnb
- kw.update(load_in_4bit=True)
- if kw["device_map"] is None:
- kw["device_map"] = "auto"
- except Exception:
- pass
  return kw

-
  def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
  print(f"[Model] Loading base model from {MODEL_ID}...")
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))

- #if ADAPTER_ID:
- # if not _HAS_PEFT:
- # raise RuntimeError("peft is required when ADAPTER_ID is set.")
- # print(f"[Model] Loading adapter from {ADAPTER_ID}...")
- # peft_kwargs: Dict[str, Any] = {"token": HF_TOKEN}
- # if ADAPTER_SUBFOLDER:
- # peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
- # model = PeftModel.from_pretrained(model, ADAPTER_ID, is_trainable=False, **peft_kwargs)

  model.eval()
- # Ensure a valid pad_token_id is set; some OSS checkpoints reuse eos as pad
  if getattr(model.config, "pad_token_id", None) is None:
  model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
  model.config.use_cache = True
- print("[Model] Model loaded successfully")
  return model

  # -----------------------
  # Harmony formatting
  # -----------------------
-
  def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
- """Build a Harmony-formatted prompt. If Harmony is available, return **token IDs**
- rendered by `openai_harmony` (authoritative). Otherwise fall back to the
- tokenizer's chat template and return a string.
- """
  if HARMONY_AVAILABLE and harmony_encoding is not None:
  effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
  effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)
@@ -168,7 +340,6 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
  .with_required_channels(REQUIRED_CHANNELS)
  )

- # Use first system message as developer instructions if present, else SYSTEM_DEF
  sys_text = SYSTEM_DEF
  rest: List[Dict[str, str]] = messages or []
  if rest and rest[0].get("role") == "system":
@@ -191,7 +362,7 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
  convo = Conversation.from_messages(harmony_messages)
  return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

- # Fallback: tokenizer chat template -> string prompt
  if not messages or messages[0].get("role") != "system":
  messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
  return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
@@ -199,14 +370,11 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
  def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
  """Parse response tokens using Harmony format to extract channels."""
  if not HARMONY_AVAILABLE:
- # Fallback: just decode and extract final channel manually
  text = tokenizer.decode(tokens, skip_special_tokens=False)
  return {"final": extract_final_channel_fallback(text), "raw": text}

- # Parse messages from completion tokens
  parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)

- # Extract content by channel
  channels = {}
  for msg in parsed_messages:
  channel = msg.channel if hasattr(msg, 'channel') else "final"
@@ -214,16 +382,13 @@ def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
  channels[channel] = ""
  channels[channel] += "".join([getattr(part, "text", str(part)) for part in (msg.content if isinstance(msg.content, list) else [msg.content])])

- # Ensure we have a final channel
  if "final" not in channels:
  channels["final"] = " ".join(channels.values())

  return channels

  def extract_final_channel_fallback(text: str) -> str:
- """Robustly extract the <final> channel from decoded Harmony text.
- Works even if parsing fails or the model emits extra headers.
- """
  try:
  chunks: Dict[str, str] = {}
  pieces = text.split("<|channel|>")
@@ -233,7 +398,6 @@ def extract_final_channel_fallback(text: str) -> str:
  continue
  ch = seg[:name_end].strip()
  body_start = name_end + len("<|message|>")
- # end at next channel/end/return marker
  next_pos = len(seg)
  for delim in ("<|channel|>", "<|end|>", "<|return|>"):
  p = seg.find(delim, body_start)
@@ -244,7 +408,6 @@ def extract_final_channel_fallback(text: str) -> str:
  final_txt = (chunks.get("final", "").strip())
  if final_txt:
  return final_txt
- # Fallback: everything after last final marker up to a terminator
  if "<|channel|>final<|message|>" in text:
  tail = text.split("<|channel|>final<|message|>")[-1]
  for delim in ("<|return|>", "<|end|>", "<|channel|>"):
@@ -260,7 +423,6 @@ def extract_final_channel_fallback(text: str) -> str:
  # -----------------------
  # Rose guidance
  # -----------------------
-
  def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
  """Create vocab bias from {token: weight}."""
  vocab_size = len(tokenizer)
@@ -286,6 +448,9 @@ class RoseGuidedLogits(torch.nn.Module):
  def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  return scores + self.alpha * self.bias_vec.to(scores.device)

  @spaces.GPU(duration=120)
  def zerogpu_generate(full_prompt,
  gen_kwargs: Dict[str, Any],
@@ -293,12 +458,12 @@ def zerogpu_generate(full_prompt,
  rose_alpha: float,
  rose_score: Optional[float],
  seed: Optional[int]) -> Dict[str, str]:
- """Run inference on GPU and return parsed channels."""
  try:
  if seed is not None:
  torch.manual_seed(int(seed))

- # Load model
  model = _load_model_on("auto")

  # Setup logits processor for Rose guidance
@@ -308,7 +473,7 @@ def zerogpu_generate(full_prompt,
  eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
  logits_processor = [RoseGuidedLogits(bias, eff_alpha)]

- # Tokenize / prepare inputs
  device = next(model.parameters()).device
  if HARMONY_AVAILABLE and isinstance(full_prompt, list):
  input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
@@ -319,11 +484,10 @@ def zerogpu_generate(full_prompt,
  enc = tokenizer(full_prompt, return_tensors="pt")
  inputs = enc.to(device)
  prompt_len = int(inputs["input_ids"].shape[1])
- # Guarantee attention_mask exists; avoids pad==eos ambiguity warnings
  if "attention_mask" not in inputs:
  inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
  # Generate
- # Build EOS list: use ONLY Harmony assistant-action stops (per OpenAI docs)
  eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id

  out_ids = model.generate(
@@ -341,29 +505,28 @@ def zerogpu_generate(full_prompt,
  min_new_tokens=1,
  )

- # Extract generated tokens only
  out_list = out_ids[0].tolist()
  gen_ids = out_list[prompt_len:]
- # Truncate at first Harmony stop token if present
  if HARMONY_AVAILABLE:
  for sid in HARMONY_STOP_IDS:
  if sid in gen_ids:
  gen_ids = gen_ids[:gen_ids.index(sid)]
  break

- # Parse response with Harmony
  if HARMONY_AVAILABLE:
  try:
  channels = parse_harmony_response(gen_ids)
  except Exception:
- # Fallback to text parsing if Harmony parser fails
  decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
  channels = {
  "final": extract_final_channel_fallback(decoded),
  "raw": decoded
  }
  else:
- # Fallback decode + channels
  decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
  channels = {
  "final": extract_final_channel_fallback(decoded),
@@ -373,7 +536,10 @@ def zerogpu_generate(full_prompt,
  return channels

  except Exception as e:
- return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": str(e)}
  finally:
  # Cleanup
  try:
@@ -387,7 +553,6 @@ def zerogpu_generate(full_prompt,
  # -----------------------
  # Gradio handlers
  # -----------------------
-
  def generate_response(message: str, history: List[List[str]], system_prompt: str,
  temperature: float, top_p: float, top_k: int, max_new_tokens: int,
  do_sample: bool, seed: Optional[int],
@@ -395,14 +560,11 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
  rose_tokens: str, rose_json: str,
  show_thinking: bool = False,
  reasoning_effort: str = "high") -> str:
- """
- Generate response with proper CoT handling using Harmony format.
- """
  try:
- # Build message list
  messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]

- # Add history
  if history:
  for turn in history:
  if isinstance(turn, (list, tuple)) and len(turn) >= 2:
@@ -412,17 +574,15 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
  if assistant_msg:
  messages.append({"role": "assistant", "content": str(assistant_msg)})

- # Add current message
  messages.append({"role": "user", "content": str(message)})

- # Create Harmony-formatted prompt
  if HARMONY_AVAILABLE:
- prompt = create_harmony_prompt(messages, reasoning_effort) # returns token IDs
  else:
- # Fallback to tokenizer template (string)
  prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

- # Build Rose map if enabled
  rose_map: Optional[Dict[str, float]] = None
  if rose_enable:
  rose_map = {}
@@ -449,7 +609,7 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
  if not rose_map:
  rose_map = None

- # Generate with model
  channels = zerogpu_generate(
  prompt,
  {
@@ -458,6 +618,8 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
  "top_p": float(top_p),
  "top_k": int(top_k) if top_k > 0 else None,
  "max_new_tokens": int(max_new_tokens),
  },
  rose_map,
  float(rose_alpha),
@@ -467,7 +629,6 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str

  # Format response
  if show_thinking:
- # Show all channels
  response = "## Chain of Thought:\n\n"
  for channel, content in channels.items():
  if channel != "final" and content:
@@ -475,24 +636,25 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
  response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
  return response
  else:
- # Just show the final response
  return channels.get("final", "No final response generated")

  except Exception as e:
- return f"[Error] {type(e).__name__}: {str(e)}"

  # -----------------------
  # UI
  # -----------------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
  gr.Markdown(
- """
  # Mirel – Harmony Chain-of-Thought Inference

- OSS-20B model using Harmony format with thinking channels.
- The model thinks through problems in internal channels before providing a final response.

- **Note:** Install `openai-harmony` for full Harmony support: `pip install openai-harmony`
  """
  )

@@ -542,7 +704,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
  value=""
  )

- # Chat interface - using only valid parameters
  chat = gr.ChatInterface(
  fn=generate_response,
  type="messages",
@@ -552,7 +714,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
  rose_tokens, rose_json, show_thinking, reasoning_effort
  ],
  title="Chat with Mirel",
- description="A chain-of-thought model using Harmony format",
  examples=[
  ["Hello! Can you introduce yourself?"],
  ["What is the capital of France?"],
@@ -566,12 +728,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
  """
  ---
  ### Configuration:
- - **Model**: Set `MODEL_ID` env var (default: openai/gpt-oss-20b)
- - **Adapter**: Set `ADAPTER_ID` and optionally `ADAPTER_SUBFOLDER`
  - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
- - **Harmony**: Install with `pip install openai-harmony` for proper channel support
-
- The model uses Harmony format with thinking channels (`thinking`, `analysis`, `final`).
  """
  )

@@ -580,4 +740,4 @@ if __name__ == "__main__":
  server_name="0.0.0.0",
  server_port=7860,
  share=False
- )
  """
  Mirel Harmony Inference – HF Space (Gradio)
+ ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B
+ Proper LoRA adapter loading and conversion for MX compatibility
  Single file: app.py
  """
  from __future__ import annotations
+ import os, gc, json, threading, torch, warnings
  from dataclasses import dataclass
+ from typing import List, Dict, Optional, Any, Union
  from datetime import datetime
  import gradio as gr
  import spaces # required for ZeroGPU
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+ import numpy as np
+
+ # Suppress warnings about MX format
+ warnings.filterwarnings("ignore", message=".*microscaling.*")
+ warnings.filterwarnings("ignore", message=".*mx.*")

  # Import Harmony components
  try:

  # -----------------------
  # Config & runtime modes
  # -----------------------
+ # MX format uses special dtypes - we need to handle this properly
  MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
  ADAPTER_ID = os.getenv("ADAPTER_ID") or None
  ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None
  ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
  SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
  MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
  ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
+
+ # For GPT-OSS models, we need specific handling
+ IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
+ USE_MX_FORMAT = os.getenv("USE_MX_FORMAT", "1" if IS_GPT_OSS else "0") == "1"

  # Harmony channels for CoT
  REQUIRED_CHANNELS = ["analysis", "commentary", "final"]

+ # HF Auth
  HF_TOKEN: Optional[str] = (
  os.getenv("HF_TOKEN")
  or os.getenv("HUGGING_FACE_HUB_TOKEN")

  raise

  # -----------------------
+ # PEFT and MX Format Support
  # -----------------------
  try:
+ from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
  _HAS_PEFT = True
  except Exception:
  _HAS_PEFT = False
+ print("[Warning] PEFT not available. Install with: pip install peft")
+
+ # Try to import microscaling support if available
+ try:
+ import msamp
+ _HAS_MSAMP = True
+ print("[Info] Microsoft AMP (msamp) available for MX format support")
+ except ImportError:
+ _HAS_MSAMP = False
+ print("[Info] msamp not available - using fallback MX handling")
+
+ # -----------------------
+ # MX Format Conversion
+ # -----------------------
+ def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Convert fp32 LoRA weights to be compatible with MX format base model.
+ MX models expect specific dtype handling.
+ """
+ converted = {}
+
+ for key, tensor in lora_state_dict.items():
+ if tensor is None:
+ converted[key] = tensor
+ continue
+
+ # LoRA weights (lora_A, lora_B) need special handling
+ if 'lora_' in key:
+ # For MX compatibility, we keep weights in fp32 but ensure proper scaling
+ # MX format internally handles quantization, we just need clean fp32 inputs
+ if tensor.dtype != torch.float32:
+ tensor = tensor.to(torch.float32)
+
+ # Ensure weights are in reasonable range for MX quantization
+ # MX format works best with weights in [-1, 1] range
+ if 'lora_A' in key:
+ # Input projection - initialize with small values
+ std = 1.0 / torch.sqrt(torch.tensor(tensor.shape[1], dtype=torch.float32))
+ if tensor.std() > std * 10: # If weights are too large
+ print(f"[MX Convert] Scaling down {key} from std={tensor.std():.4f} to {std:.4f}")
+ tensor = tensor * (std / tensor.std())
+ elif 'lora_B' in key:
+ # Output projection - should be near zero initially
+ if tensor.abs().max() > 0.1:
+ print(f"[MX Convert] Scaling down {key} max={tensor.abs().max():.4f}")
+ tensor = tensor * 0.01
+
+ converted[key] = tensor
+ else:
+ # Non-LoRA weights (like embeddings) stay as-is
+ converted[key] = tensor
+
+ return converted
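# Editor's sketch (not part of this commit): a tiny standalone check of the rescaling
# rules above on a toy LoRA pair. The key names and shapes are illustrative only.
import torch

toy_lora = {
    "layers.0.q_proj.lora_A.weight": torch.randn(8, 2880) * 0.5,   # std far above 1/sqrt(fan_in)
    "layers.0.q_proj.lora_B.weight": torch.randn(2880, 8) * 0.3,   # max far above 0.1
}
for key, t in toy_lora.items():
    if "lora_A" in key:
        std = 1.0 / torch.sqrt(torch.tensor(t.shape[1], dtype=torch.float32))
        if t.std() > std * 10:
            t = t * (std / t.std())        # pulled down to roughly 1/sqrt(2880) ≈ 0.019
    elif "lora_B" in key:
        if t.abs().max() > 0.1:
            t = t * 0.01                   # shrunk toward the near-zero init LoRA expects
    print(key, round(float(t.std()), 4))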

+ def prepare_model_for_mx_lora(model, adapter_path: str):
+ """
+ Prepare and attach LoRA adapter to MX format model.
+ Handles the special requirements of GPT-OSS MX models.
+ """
+ if not _HAS_PEFT:
+ raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
+
+ print(f"[LoRA] Loading adapter from {adapter_path}")
+
+ # Load the LoRA config
+ peft_config = PeftConfig.from_pretrained(adapter_path, token=HF_TOKEN)
+
+ # Load the LoRA weights
+ from safetensors.torch import load_file
+ import os.path as osp
+
+ adapter_weights_path = osp.join(adapter_path, "adapter_model.safetensors")
+ if not osp.exists(adapter_weights_path):
+ adapter_weights_path = osp.join(adapter_path, "adapter_model.bin")
+ if osp.exists(adapter_weights_path):
+ adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
+ else:
+ raise FileNotFoundError(f"No adapter weights found at {adapter_path}")
+ else:
+ adapter_weights = load_file(adapter_weights_path)
+
+ # Convert weights for MX compatibility
+ print("[LoRA] Converting fp32 weights for MX format compatibility...")
+ adapter_weights = convert_fp32_lora_to_mx_compatible(adapter_weights)
+
+ # Create PEFT model with special handling for MX
+ print("[LoRA] Attaching LoRA to base model...")
+
+ # For MX models, we need to be careful about dtype
+ # The base model uses MX format internally, but the interface should be fp32
+ model = PeftModel.from_pretrained(
+ model,
+ adapter_path,
+ is_trainable=False,
+ token=HF_TOKEN,
+ # Don't specify torch_dtype here - let it match the base model
+ )
+
+ # Manually update the adapter weights with our converted versions
+ model.load_state_dict(adapter_weights, strict=False)
+
+ print("[LoRA] Successfully attached LoRA adapter with MX compatibility")
+ return model

+ # -----------------------
+ # Model loading with MX support
+ # -----------------------
  def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
+ """Build kwargs for model loading with MX format support."""
  kw: Dict[str, Any] = dict(
  device_map=device_map,
  trust_remote_code=True,
  low_cpu_mem_usage=True,
  token=HF_TOKEN,
  )
+
+ if IS_GPT_OSS and USE_MX_FORMAT:
+ # GPT-OSS models use MX format
+ # Don't specify torch_dtype - let the model use its native MX format
+ print("[Model] Using MX format for GPT-OSS model")
+ kw.update({
+ "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
+ # MX models handle their own dtype internally
+ # Don't force a dtype here
+ })
+ else:
+ # Non-MX models
+ kw.update({
+ "torch_dtype": torch.float16, # Use fp16 for non-MX models
+ "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
+ })
+
  return kw

  def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
+ """Load model with proper MX format handling."""
  print(f"[Model] Loading base model from {MODEL_ID}...")

+ # Load config first to check for MX format
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
+
+ # Check if this is an MX model
+ is_mx_model = (
+ IS_GPT_OSS or
+ hasattr(config, 'quantization_config') and 'mx' in str(config.quantization_config).lower() or
+ hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower()
+ )
+
+ if is_mx_model:
+ print("[Model] Detected MX format model - using special loading")
+
+ # For MX models, we need special handling
+ # The model internally uses MX quantization
+ model = AutoModelForCausalLM.from_pretrained(
+ MODEL_ID,
+ config=config,
+ trust_remote_code=True,
+ device_map=device_map,
+ low_cpu_mem_usage=True,
+ token=HF_TOKEN,
+ # Let the model handle its own dtype
+ attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
+ )
+
+ # Verify the model loaded correctly
+ print(f"[Model] Model dtype: {next(model.parameters()).dtype}")
+ print(f"[Model] Model device: {next(model.parameters()).device}")
+
+ else:
+ # Standard model loading
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))
+
+ # Load and attach LoRA adapter if specified
+ if ADAPTER_ID:
+ try:
+ if is_mx_model:
+ # Use special MX-compatible LoRA loading
+ model = prepare_model_for_mx_lora(model, ADAPTER_ID)
+ else:
+ # Standard PEFT loading for non-MX models
+ if not _HAS_PEFT:
+ raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
+ print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
+ model = PeftModel.from_pretrained(
+ model,
+ ADAPTER_ID,
+ is_trainable=False,
+ token=HF_TOKEN
+ )
+
+ print("[Model] Successfully loaded with LoRA adapter")
+
+ # Optionally merge adapter for better performance
+ merge_adapter = os.getenv("MERGE_ADAPTER", "0") == "1"
+ if merge_adapter and hasattr(model, 'merge_and_unload'):
+ print("[Model] Merging adapter into base model...")
+ model = model.merge_and_unload()
+ print("[Model] Adapter merged successfully")
+
+ except Exception as e:
+ print(f"[Error] Failed to load adapter: {e}")
+ print("[Warning] Continuing with base model only")

  model.eval()
+
+ # Ensure proper config
  if getattr(model.config, "pad_token_id", None) is None:
  model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
  model.config.use_cache = True
+
+ print(f"[Model] Model loaded successfully - Type: {'MX Format' if is_mx_model else 'Standard'}")
  return model

  # -----------------------
  # Harmony formatting
  # -----------------------
  def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
+ """Build a Harmony-formatted prompt."""
  if HARMONY_AVAILABLE and harmony_encoding is not None:
  effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
  effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)

  .with_required_channels(REQUIRED_CHANNELS)
  )

  sys_text = SYSTEM_DEF
  rest: List[Dict[str, str]] = messages or []
  if rest and rest[0].get("role") == "system":

  convo = Conversation.from_messages(harmony_messages)
  return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

+ # Fallback: tokenizer chat template
  if not messages or messages[0].get("role") != "system":
  messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
  return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

  def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
  """Parse response tokens using Harmony format to extract channels."""
  if not HARMONY_AVAILABLE:
  text = tokenizer.decode(tokens, skip_special_tokens=False)
  return {"final": extract_final_channel_fallback(text), "raw": text}

  parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)

  channels = {}
  for msg in parsed_messages:
  channel = msg.channel if hasattr(msg, 'channel') else "final"

  channels[channel] = ""
  channels[channel] += "".join([getattr(part, "text", str(part)) for part in (msg.content if isinstance(msg.content, list) else [msg.content])])

  if "final" not in channels:
  channels["final"] = " ".join(channels.values())

  return channels

  def extract_final_channel_fallback(text: str) -> str:
+ """Extract the <final> channel from decoded Harmony text."""
  try:
  chunks: Dict[str, str] = {}
  pieces = text.split("<|channel|>")

  continue
  ch = seg[:name_end].strip()
  body_start = name_end + len("<|message|>")
  next_pos = len(seg)
  for delim in ("<|channel|>", "<|end|>", "<|return|>"):
  p = seg.find(delim, body_start)

  final_txt = (chunks.get("final", "").strip())
  if final_txt:
  return final_txt
  if "<|channel|>final<|message|>" in text:
  tail = text.split("<|channel|>final<|message|>")[-1]
  for delim in ("<|return|>", "<|end|>", "<|channel|>"):

  # -----------------------
  # Rose guidance
  # -----------------------
  def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
  """Create vocab bias from {token: weight}."""
  vocab_size = len(tokenizer)

  def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
  return scores + self.alpha * self.bias_vec.to(scores.device)
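# Editor's sketch (not part of this commit): the Rose guidance idea end to end.
# build_bias_from_tokens's body is collapsed in this diff, so toy_bias below is an
# assumption about its behavior (token -> vocab-sized additive bias), not a copy of it.
import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the illustration

def toy_bias(tokenizer, mapping):
    bias = torch.zeros(len(tokenizer))
    for text, weight in mapping.items():
        for tid in tokenizer.encode(text, add_special_tokens=False):
            bias[tid] += weight
    return bias

bias_vec = toy_bias(tok, {"rose": 2.0, "thorn": -1.0})
scores = torch.zeros(1, len(tok))      # stand-in for one step of logits
guided = scores + 0.5 * bias_vec       # alpha = 0.5, mirroring RoseGuidedLogits.forward
print(int((guided != 0).sum()), "vocab entries biased")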

+ # -----------------------
+ # Generation
+ # -----------------------
  @spaces.GPU(duration=120)
  def zerogpu_generate(full_prompt,
  gen_kwargs: Dict[str, Any],

  rose_alpha: float,
  rose_score: Optional[float],
  seed: Optional[int]) -> Dict[str, str]:
+ """Run inference on GPU with MX format support."""
  try:
  if seed is not None:
  torch.manual_seed(int(seed))

+ # Load model with MX support
  model = _load_model_on("auto")

  # Setup logits processor for Rose guidance

  eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
  logits_processor = [RoseGuidedLogits(bias, eff_alpha)]

+ # Prepare inputs
  device = next(model.parameters()).device
  if HARMONY_AVAILABLE and isinstance(full_prompt, list):
  input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)

  enc = tokenizer(full_prompt, return_tensors="pt")
  inputs = enc.to(device)
  prompt_len = int(inputs["input_ids"].shape[1])
  if "attention_mask" not in inputs:
  inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
+
  # Generate
  eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id

  out_ids = model.generate(

  min_new_tokens=1,
  )

+ # Extract generated tokens
  out_list = out_ids[0].tolist()
  gen_ids = out_list[prompt_len:]
+
+ # Truncate at stop tokens
  if HARMONY_AVAILABLE:
  for sid in HARMONY_STOP_IDS:
  if sid in gen_ids:
  gen_ids = gen_ids[:gen_ids.index(sid)]
  break

+ # Parse response
  if HARMONY_AVAILABLE:
  try:
  channels = parse_harmony_response(gen_ids)
  except Exception:
  decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
  channels = {
  "final": extract_final_channel_fallback(decoded),
  "raw": decoded
  }
  else:
  decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
  channels = {
  "final": extract_final_channel_fallback(decoded),

  return channels

  except Exception as e:
+ import traceback
+ error_trace = traceback.format_exc()
+ print(f"[Error] Generation failed:\n{error_trace}")
+ return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": error_trace}
  finally:
  # Cleanup
  try:

  # -----------------------
  # Gradio handlers
  # -----------------------
  def generate_response(message: str, history: List[List[str]], system_prompt: str,
  temperature: float, top_p: float, top_k: int, max_new_tokens: int,
  do_sample: bool, seed: Optional[int],

  rose_tokens: str, rose_json: str,
  show_thinking: bool = False,
  reasoning_effort: str = "high") -> str:
+ """Generate response with CoT handling."""
  try:
+ # Build messages
  messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]

  if history:
  for turn in history:
  if isinstance(turn, (list, tuple)) and len(turn) >= 2:

  if assistant_msg:
  messages.append({"role": "assistant", "content": str(assistant_msg)})

  messages.append({"role": "user", "content": str(message)})

+ # Create prompt
  if HARMONY_AVAILABLE:
+ prompt = create_harmony_prompt(messages, reasoning_effort)
  else:
  prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

+ # Build Rose map
  rose_map: Optional[Dict[str, float]] = None
  if rose_enable:
  rose_map = {}

  if not rose_map:
  rose_map = None

+ # Generate
  channels = zerogpu_generate(
  prompt,
  {

  "top_p": float(top_p),
  "top_k": int(top_k) if top_k > 0 else None,
  "max_new_tokens": int(max_new_tokens),
+ "repetition_penalty": 1.1,
+ "no_repeat_ngram_size": 6,
  },
  rose_map,
  float(rose_alpha),

  # Format response
  if show_thinking:
  response = "## Chain of Thought:\n\n"
  for channel, content in channels.items():
  if channel != "final" and content:

  response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
  return response
  else:
  return channels.get("final", "No final response generated")

  except Exception as e:
+ import traceback
+ return f"[Error] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"

  # -----------------------
  # UI
  # -----------------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
  gr.Markdown(
+ f"""
  # Mirel – Harmony Chain-of-Thought Inference

+ **Model**: {MODEL_ID} {'(MX Format)' if USE_MX_FORMAT else ''}
+ **Adapter**: {ADAPTER_ID or 'None'}
+ **Status**: {'✅ Harmony Available' if HARMONY_AVAILABLE else '⚠️ Harmony Not Installed'}

+ The model uses internal thinking channels before providing final responses.
  """
  )

  value=""
  )

+ # Chat interface
  chat = gr.ChatInterface(
  fn=generate_response,
  type="messages",

  rose_tokens, rose_json, show_thinking, reasoning_effort
  ],
  title="Chat with Mirel",
+ description="Chain-of-thought model with MX format support",
  examples=[
  ["Hello! Can you introduce yourself?"],
  ["What is the capital of France?"],

  """
  ---
  ### Configuration:
+ - **MX Format**: Automatically detected for GPT-OSS models
+ - **LoRA Support**: fp32 LoRA adapters are converted for MX compatibility
+ - **Merge Adapter**: Set `MERGE_ADAPTER=1` to merge LoRA into base model
  - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
  """
  )

  server_name="0.0.0.0",
  server_port=7860,
  share=False
+ )
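
For local testing, a minimal sketch (not part of the commit) of the environment variables this app.py reads at import time; the adapter repo name is a placeholder.

# Minimal local-run sketch; the adapter repo name is a placeholder, not from the commit.
import os

os.environ["MODEL_ID"] = "openai/gpt-oss-20b"    # base model (the app's default)
os.environ["ADAPTER_ID"] = "your-org/your-lora"  # hypothetical LoRA repo, or leave unset for base only
os.environ["MERGE_ADAPTER"] = "1"                # merge the LoRA into the base weights after loading
os.environ["USE_MX_FORMAT"] = "1"                # keep native MX handling for GPT-OSS
os.environ["MAX_NEW_TOKENS"] = "256"

# Set these before importing app.py, since the module reads them at import time;
# running `python app.py` afterwards launches the Gradio demo on port 7860.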