Spaces: Running on Zero
AbstractPhil committed
Commit 5c9afc5 · 1 Parent(s): a2f6c58
claude helps again
app.py CHANGED
@@ -1,17 +1,22 @@
 """
 Mirel Harmony Inference – HF Space (Gradio)
-ZeroGPU-ready, Harmony formatting,
-
+ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B
+Proper LoRA adapter loading and conversion for MX compatibility
 Single file: app.py
 """
 from __future__ import annotations
-import os, gc, json, threading, torch
+import os, gc, json, threading, torch, warnings
 from dataclasses import dataclass
-from typing import List, Dict, Optional, Any
+from typing import List, Dict, Optional, Any, Union
 from datetime import datetime
 import gradio as gr
 import spaces  # required for ZeroGPU
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import numpy as np
+
+# Suppress warnings about MX format
+warnings.filterwarnings("ignore", message=".*microscaling.*")
+warnings.filterwarnings("ignore", message=".*mx.*")
 
 # Import Harmony components
 try:
@@ -34,22 +39,23 @@ except ImportError:
 # -----------------------
 # Config & runtime modes
 # -----------------------
-
-
+# MX format uses special dtypes - we need to handle this properly
 MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
 ADAPTER_ID = os.getenv("ADAPTER_ID") or None
 ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None
 ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
-DTYPE = DTYPE_MAP.get(os.getenv("DTYPE", "bf16").lower(), torch.bfloat16)
 SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
 MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
 ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
-
+
+# For GPT-OSS models, we need specific handling
+IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
+USE_MX_FORMAT = os.getenv("USE_MX_FORMAT", "1" if IS_GPT_OSS else "0") == "1"
 
 # Harmony channels for CoT
 REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
 
-# HF Auth
+# HF Auth
 HF_TOKEN: Optional[str] = (
     os.getenv("HF_TOKEN")
     or os.getenv("HUGGING_FACE_HUB_TOKEN")
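For reference, a minimal sketch of how these environment variables could be overridden for a Space run; every value below is an illustrative placeholder, not something this commit sets:

# Hypothetical Space configuration; values are examples only.
import os
os.environ["MODEL_ID"] = "openai/gpt-oss-20b"   # base model (same as the default above)
os.environ["ADAPTER_ID"] = "user/mirel-lora"    # hypothetical LoRA adapter repo id
os.environ["USE_MX_FORMAT"] = "1"               # defaults to "1" when MODEL_ID contains "gpt-oss"
os.environ["ATTN_IMPL"] = "eager"
os.environ["MAX_NEW_TOKENS"] = "256"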
@@ -96,65 +102,231 @@ except Exception as e:
     raise
 
 # -----------------------
-#
+# PEFT and MX Format Support
 # -----------------------
 try:
-    from peft import PeftModel
+    from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
     _HAS_PEFT = True
 except Exception:
     _HAS_PEFT = False
+    print("[Warning] PEFT not available. Install with: pip install peft")
+
+# Try to import microscaling support if available
+try:
+    import msamp
+    _HAS_MSAMP = True
+    print("[Info] Microsoft AMP (msamp) available for MX format support")
+except ImportError:
+    _HAS_MSAMP = False
+    print("[Info] msamp not available - using fallback MX handling")
+
+# -----------------------
+# MX Format Conversion
+# -----------------------
+def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert fp32 LoRA weights to be compatible with MX format base model.
+    MX models expect specific dtype handling.
+    """
+    converted = {}
+
+    for key, tensor in lora_state_dict.items():
+        if tensor is None:
+            converted[key] = tensor
+            continue
+
+        # LoRA weights (lora_A, lora_B) need special handling
+        if 'lora_' in key:
+            # For MX compatibility, we keep weights in fp32 but ensure proper scaling
+            # MX format internally handles quantization, we just need clean fp32 inputs
+            if tensor.dtype != torch.float32:
+                tensor = tensor.to(torch.float32)
+
+            # Ensure weights are in reasonable range for MX quantization
+            # MX format works best with weights in [-1, 1] range
+            if 'lora_A' in key:
+                # Input projection - initialize with small values
+                std = 1.0 / torch.sqrt(torch.tensor(tensor.shape[1], dtype=torch.float32))
+                if tensor.std() > std * 10:  # If weights are too large
+                    print(f"[MX Convert] Scaling down {key} from std={tensor.std():.4f} to {std:.4f}")
+                    tensor = tensor * (std / tensor.std())
+            elif 'lora_B' in key:
+                # Output projection - should be near zero initially
+                if tensor.abs().max() > 0.1:
+                    print(f"[MX Convert] Scaling down {key} max={tensor.abs().max():.4f}")
+                    tensor = tensor * 0.01
+
+            converted[key] = tensor
+        else:
+            # Non-LoRA weights (like embeddings) stay as-is
+            converted[key] = tensor
+
+    return converted
 
+def prepare_model_for_mx_lora(model, adapter_path: str):
+    """
+    Prepare and attach LoRA adapter to MX format model.
+    Handles the special requirements of GPT-OSS MX models.
+    """
+    if not _HAS_PEFT:
+        raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
+
+    print(f"[LoRA] Loading adapter from {adapter_path}")
+
+    # Load the LoRA config
+    peft_config = PeftConfig.from_pretrained(adapter_path, token=HF_TOKEN)
+
+    # Load the LoRA weights
+    from safetensors.torch import load_file
+    import os.path as osp
+
+    adapter_weights_path = osp.join(adapter_path, "adapter_model.safetensors")
+    if not osp.exists(adapter_weights_path):
+        adapter_weights_path = osp.join(adapter_path, "adapter_model.bin")
+        if osp.exists(adapter_weights_path):
+            adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
+        else:
+            raise FileNotFoundError(f"No adapter weights found at {adapter_path}")
+    else:
+        adapter_weights = load_file(adapter_weights_path)
+
+    # Convert weights for MX compatibility
+    print("[LoRA] Converting fp32 weights for MX format compatibility...")
+    adapter_weights = convert_fp32_lora_to_mx_compatible(adapter_weights)
+
+    # Create PEFT model with special handling for MX
+    print("[LoRA] Attaching LoRA to base model...")
+
+    # For MX models, we need to be careful about dtype
+    # The base model uses MX format internally, but the interface should be fp32
+    model = PeftModel.from_pretrained(
+        model,
+        adapter_path,
+        is_trainable=False,
+        token=HF_TOKEN,
+        # Don't specify torch_dtype here - let it match the base model
+    )
+
+    # Manually update the adapter weights with our converted versions
+    model.load_state_dict(adapter_weights, strict=False)
+
+    print("[LoRA] Successfully attached LoRA adapter with MX compatibility")
+    return model
 
+# -----------------------
+# Model loading with MX support
+# -----------------------
 def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
+    """Build kwargs for model loading with MX format support."""
     kw: Dict[str, Any] = dict(
-        torch_dtype=DTYPE,
         device_map=device_map,
-        attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
         trust_remote_code=True,
         low_cpu_mem_usage=True,
         token=HF_TOKEN,
     )
-
-
-
-
-
-
-
-
+
+    if IS_GPT_OSS and USE_MX_FORMAT:
+        # GPT-OSS models use MX format
+        # Don't specify torch_dtype - let the model use its native MX format
+        print("[Model] Using MX format for GPT-OSS model")
+        kw.update({
+            "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
+            # MX models handle their own dtype internally
+            # Don't force a dtype here
+        })
+    else:
+        # Non-MX models
+        kw.update({
+            "torch_dtype": torch.float16,  # Use fp16 for non-MX models
+            "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
+        })
+
     return kw
 
-
 def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
+    """Load model with proper MX format handling."""
     print(f"[Model] Loading base model from {MODEL_ID}...")
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))
 
-    #
-
-
-    #
-
-
-
-
+    # Load config first to check for MX format
+    config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
+
+    # Check if this is an MX model
+    is_mx_model = (
+        IS_GPT_OSS or
+        hasattr(config, 'quantization_config') and 'mx' in str(config.quantization_config).lower() or
+        hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower()
+    )
+
+    if is_mx_model:
+        print("[Model] Detected MX format model - using special loading")
+
+        # For MX models, we need special handling
+        # The model internally uses MX quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            config=config,
+            trust_remote_code=True,
+            device_map=device_map,
+            low_cpu_mem_usage=True,
+            token=HF_TOKEN,
+            # Let the model handle its own dtype
+            attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
+        )
+
+        # Verify the model loaded correctly
+        print(f"[Model] Model dtype: {next(model.parameters()).dtype}")
+        print(f"[Model] Model device: {next(model.parameters()).device}")
+
+    else:
+        # Standard model loading
+        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))
+
+    # Load and attach LoRA adapter if specified
+    if ADAPTER_ID:
+        try:
+            if is_mx_model:
+                # Use special MX-compatible LoRA loading
+                model = prepare_model_for_mx_lora(model, ADAPTER_ID)
+            else:
+                # Standard PEFT loading for non-MX models
+                if not _HAS_PEFT:
+                    raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
+                print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
+                model = PeftModel.from_pretrained(
+                    model,
+                    ADAPTER_ID,
+                    is_trainable=False,
+                    token=HF_TOKEN
+                )
+
+            print("[Model] Successfully loaded with LoRA adapter")
+
+            # Optionally merge adapter for better performance
+            merge_adapter = os.getenv("MERGE_ADAPTER", "0") == "1"
+            if merge_adapter and hasattr(model, 'merge_and_unload'):
+                print("[Model] Merging adapter into base model...")
+                model = model.merge_and_unload()
+                print("[Model] Adapter merged successfully")
+
+        except Exception as e:
+            print(f"[Error] Failed to load adapter: {e}")
+            print("[Warning] Continuing with base model only")
 
     model.eval()
-
+
+    # Ensure proper config
     if getattr(model.config, "pad_token_id", None) is None:
         model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
     model.config.use_cache = True
-
+
+    print(f"[Model] Model loaded successfully - Type: {'MX Format' if is_mx_model else 'Standard'}")
     return model
 
 # -----------------------
 # Harmony formatting
 # -----------------------
-
 def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
-    """Build a Harmony-formatted prompt.
-    rendered by `openai_harmony` (authoritative). Otherwise fall back to the
-    tokenizer's chat template and return a string.
-    """
+    """Build a Harmony-formatted prompt."""
     if HARMONY_AVAILABLE and harmony_encoding is not None:
         effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
         effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)
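A quick way to sanity-check the conversion helper added above is to run it on a dummy LoRA state dict. The key names and shapes below are made up for illustration, and the snippet assumes convert_fp32_lora_to_mx_compatible from this file is in scope:

# Illustrative only: oversized dummy tensors so both rescaling branches trigger.
import torch
dummy = {
    "layers.0.attn.q_proj.lora_A.weight": torch.randn(16, 2880) * 0.5,
    "layers.0.attn.q_proj.lora_B.weight": torch.randn(2880, 16) * 0.5,
}
converted = convert_fp32_lora_to_mx_compatible(dummy)
for key, tensor in converted.items():
    print(key, tensor.dtype, round(float(tensor.std()), 4))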
@@ -168,7 +340,6 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
             .with_required_channels(REQUIRED_CHANNELS)
         )
 
-        # Use first system message as developer instructions if present, else SYSTEM_DEF
         sys_text = SYSTEM_DEF
         rest: List[Dict[str, str]] = messages or []
         if rest and rest[0].get("role") == "system":
@@ -191,7 +362,7 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
         convo = Conversation.from_messages(harmony_messages)
         return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
 
-    # Fallback: tokenizer chat template
+    # Fallback: tokenizer chat template
     if not messages or messages[0].get("role") != "system":
         messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
     return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
@@ -199,14 +370,11 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
 def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
     """Parse response tokens using Harmony format to extract channels."""
     if not HARMONY_AVAILABLE:
-        # Fallback: just decode and extract final channel manually
         text = tokenizer.decode(tokens, skip_special_tokens=False)
         return {"final": extract_final_channel_fallback(text), "raw": text}
 
-    # Parse messages from completion tokens
     parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
 
-    # Extract content by channel
     channels = {}
     for msg in parsed_messages:
         channel = msg.channel if hasattr(msg, 'channel') else "final"
@@ -214,16 +382,13 @@ def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
             channels[channel] = ""
         channels[channel] += "".join([getattr(part, "text", str(part)) for part in (msg.content if isinstance(msg.content, list) else [msg.content])])
 
-    # Ensure we have a final channel
     if "final" not in channels:
         channels["final"] = " ".join(channels.values())
 
     return channels
 
 def extract_final_channel_fallback(text: str) -> str:
-    """
-    Works even if parsing fails or the model emits extra headers.
-    """
+    """Extract the <final> channel from decoded Harmony text."""
     try:
         chunks: Dict[str, str] = {}
         pieces = text.split("<|channel|>")
@@ -233,7 +398,6 @@ def extract_final_channel_fallback(text: str) -> str:
                 continue
             ch = seg[:name_end].strip()
             body_start = name_end + len("<|message|>")
-            # end at next channel/end/return marker
             next_pos = len(seg)
             for delim in ("<|channel|>", "<|end|>", "<|return|>"):
                 p = seg.find(delim, body_start)
@@ -244,7 +408,6 @@ def extract_final_channel_fallback(text: str) -> str:
         final_txt = (chunks.get("final", "").strip())
         if final_txt:
             return final_txt
-        # Fallback: everything after last final marker up to a terminator
         if "<|channel|>final<|message|>" in text:
             tail = text.split("<|channel|>final<|message|>")[-1]
             for delim in ("<|return|>", "<|end|>", "<|channel|>"):
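For context, the fallback parser above targets raw Harmony-style text such as the hypothetical completion below (assuming extract_final_channel_fallback from this file is in scope; the sample string is illustrative):

sample = (
    "<|channel|>analysis<|message|>Consider the question step by step.<|end|>"
    "<|channel|>final<|message|>Paris is the capital of France.<|return|>"
)
print(extract_final_channel_fallback(sample))
# Should print the final-channel text: "Paris is the capital of France."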
@@ -260,7 +423,6 @@ def extract_final_channel_fallback(text: str) -> str:
 # -----------------------
 # Rose guidance
 # -----------------------
-
 def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
     """Create vocab bias from {token: weight}."""
     vocab_size = len(tokenizer)
@@ -286,6 +448,9 @@ class RoseGuidedLogits(torch.nn.Module):
     def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores + self.alpha * self.bias_vec.to(scores.device)
 
+# -----------------------
+# Generation
+# -----------------------
 @spaces.GPU(duration=120)
 def zerogpu_generate(full_prompt,
                      gen_kwargs: Dict[str, Any],
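RoseGuidedLogits simply adds a scaled vocabulary bias to the raw scores at every decoding step. A small self-contained re-sketch of that idea (illustrative names, not the app's exact classes):

import torch

class VocabBias:
    """Minimal stand-in for RoseGuidedLogits: scores + alpha * bias."""
    def __init__(self, bias_vec: torch.Tensor, alpha: float):
        self.bias_vec = bias_vec   # shape (vocab_size,)
        self.alpha = alpha         # overall guidance strength
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores + self.alpha * self.bias_vec.to(scores.device)

# The app builds its processor the same way and, presumably, hands it to
# model.generate(..., logits_processor=[RoseGuidedLogits(bias, eff_alpha)]).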
@@ -293,12 +458,12 @@ def zerogpu_generate(full_prompt,
                      rose_alpha: float,
                      rose_score: Optional[float],
                      seed: Optional[int]) -> Dict[str, str]:
-    """Run inference on GPU
+    """Run inference on GPU with MX format support."""
     try:
         if seed is not None:
             torch.manual_seed(int(seed))
 
-        # Load model
+        # Load model with MX support
         model = _load_model_on("auto")
 
         # Setup logits processor for Rose guidance
@@ -308,7 +473,7 @@ def zerogpu_generate(full_prompt,
             eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
             logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
 
-        #
+        # Prepare inputs
         device = next(model.parameters()).device
         if HARMONY_AVAILABLE and isinstance(full_prompt, list):
             input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
@@ -319,11 +484,10 @@ def zerogpu_generate(full_prompt,
             enc = tokenizer(full_prompt, return_tensors="pt")
             inputs = enc.to(device)
             prompt_len = int(inputs["input_ids"].shape[1])
-        # Guarantee attention_mask exists; avoids pad==eos ambiguity warnings
         if "attention_mask" not in inputs:
             inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
+
         # Generate
-        # Build EOS list: use ONLY Harmony assistant-action stops (per OpenAI docs)
         eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
 
         out_ids = model.generate(
@@ -341,29 +505,28 @@ def zerogpu_generate(full_prompt,
             min_new_tokens=1,
         )
 
-        # Extract generated tokens
+        # Extract generated tokens
         out_list = out_ids[0].tolist()
         gen_ids = out_list[prompt_len:]
-
+
+        # Truncate at stop tokens
         if HARMONY_AVAILABLE:
             for sid in HARMONY_STOP_IDS:
                 if sid in gen_ids:
                     gen_ids = gen_ids[:gen_ids.index(sid)]
                     break
 
-        # Parse response
+        # Parse response
         if HARMONY_AVAILABLE:
             try:
                 channels = parse_harmony_response(gen_ids)
             except Exception:
-                # Fallback to text parsing if Harmony parser fails
                 decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
                 channels = {
                     "final": extract_final_channel_fallback(decoded),
                     "raw": decoded
                 }
         else:
-            # Fallback decode + channels
             decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
             channels = {
                 "final": extract_final_channel_fallback(decoded),
@@ -373,7 +536,10 @@ def zerogpu_generate(full_prompt,
         return channels
 
     except Exception as e:
-
+        import traceback
+        error_trace = traceback.format_exc()
+        print(f"[Error] Generation failed:\n{error_trace}")
+        return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": error_trace}
     finally:
         # Cleanup
         try:
@@ -387,7 +553,6 @@ def zerogpu_generate(full_prompt,
 # -----------------------
 # Gradio handlers
 # -----------------------
-
 def generate_response(message: str, history: List[List[str]], system_prompt: str,
                       temperature: float, top_p: float, top_k: int, max_new_tokens: int,
                       do_sample: bool, seed: Optional[int],
@@ -395,14 +560,11 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
                       rose_tokens: str, rose_json: str,
                       show_thinking: bool = False,
                       reasoning_effort: str = "high") -> str:
-    """
-    Generate response with proper CoT handling using Harmony format.
-    """
+    """Generate response with CoT handling."""
     try:
-        # Build
+        # Build messages
        messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
 
-        # Add history
         if history:
             for turn in history:
                 if isinstance(turn, (list, tuple)) and len(turn) >= 2:
@@ -412,17 +574,15 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
                     if assistant_msg:
                         messages.append({"role": "assistant", "content": str(assistant_msg)})
 
-        # Add current message
         messages.append({"role": "user", "content": str(message)})
 
-        # Create
+        # Create prompt
         if HARMONY_AVAILABLE:
-            prompt = create_harmony_prompt(messages, reasoning_effort)
+            prompt = create_harmony_prompt(messages, reasoning_effort)
         else:
-            # Fallback to tokenizer template (string)
             prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
-        # Build Rose map
+        # Build Rose map
         rose_map: Optional[Dict[str, float]] = None
         if rose_enable:
             rose_map = {}
@@ -449,7 +609,7 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
             if not rose_map:
                 rose_map = None
 
-        # Generate
+        # Generate
         channels = zerogpu_generate(
             prompt,
             {
@@ -458,6 +618,8 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
                 "top_p": float(top_p),
                 "top_k": int(top_k) if top_k > 0 else None,
                 "max_new_tokens": int(max_new_tokens),
+                "repetition_penalty": 1.1,
+                "no_repeat_ngram_size": 6,
             },
             rose_map,
             float(rose_alpha),
@@ -467,7 +629,6 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
 
         # Format response
         if show_thinking:
-            # Show all channels
             response = "## Chain of Thought:\n\n"
             for channel, content in channels.items():
                 if channel != "final" and content:
@@ -475,24 +636,25 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
             response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
             return response
         else:
-            # Just show the final response
             return channels.get("final", "No final response generated")
 
     except Exception as e:
-
+        import traceback
+        return f"[Error] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
 
 # -----------------------
 # UI
 # -----------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
-        """
+        f"""
         # Mirel – Harmony Chain-of-Thought Inference
 
-
-
+        **Model**: {MODEL_ID} {'(MX Format)' if USE_MX_FORMAT else ''}
+        **Adapter**: {ADAPTER_ID or 'None'}
+        **Status**: {'✅ Harmony Available' if HARMONY_AVAILABLE else '⚠️ Harmony Not Installed'}
 
-
-        """
+        The model uses internal thinking channels before providing final responses.
+        """
     )
 
@@ -542,7 +704,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         value=""
     )
 
-    # Chat interface
+    # Chat interface
     chat = gr.ChatInterface(
         fn=generate_response,
         type="messages",
@@ -552,7 +714,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             rose_tokens, rose_json, show_thinking, reasoning_effort
         ],
         title="Chat with Mirel",
-        description="
+        description="Chain-of-thought model with MX format support",
         examples=[
             ["Hello! Can you introduce yourself?"],
             ["What is the capital of France?"],
@@ -566,12 +728,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         """
         ---
         ### Configuration:
-        - **
-        - **
+        - **MX Format**: Automatically detected for GPT-OSS models
+        - **LoRA Support**: fp32 LoRA adapters are converted for MX compatibility
+        - **Merge Adapter**: Set `MERGE_ADAPTER=1` to merge LoRA into base model
         - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
-        - **Harmony**: Install with `pip install openai-harmony` for proper channel support
-
-        The model uses Harmony format with thinking channels (`thinking`, `analysis`, `final`).
         """
     )
 
@@ -580,4 +740,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=7860,
         share=False
-    )
+    )