""" Mirel Harmony Inference – HF Space (Gradio) ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B Proper LoRA adapter loading and conversion for MX compatibility Single file: app.py """ from __future__ import annotations import os, gc, json, threading, torch, warnings from dataclasses import dataclass from typing import List, Dict, Optional, Any, Union from datetime import datetime import gradio as gr import spaces # required for ZeroGPU from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig import numpy as np # Suppress warnings about MX format warnings.filterwarnings("ignore", message=".*microscaling.*") warnings.filterwarnings("ignore", message=".*mx.*") # Import Harmony components try: from openai_harmony import ( Author, Conversation, HarmonyEncodingName, Message, Role, SystemContent, DeveloperContent, load_harmony_encoding, ReasoningEffort ) HARMONY_AVAILABLE = True except ImportError: print("[WARNING] openai_harmony not installed. Install with: pip install openai-harmony") HARMONY_AVAILABLE = False # ----------------------- # Config & runtime modes # ----------------------- # MX format uses special dtypes - we need to handle this properly MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b") ADAPTER_ID = os.getenv("ADAPTER_ID") or None ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None ATTN_IMPL = os.getenv("ATTN_IMPL", "eager") SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.") MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256")) ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1" # For GPT-OSS models, we need specific handling IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower() USE_MX_FORMAT = os.getenv("USE_MX_FORMAT", "1" if IS_GPT_OSS else "0") == "1" # Harmony channels for CoT REQUIRED_CHANNELS = ["analysis", "commentary", "final"] # HF Auth HF_TOKEN: Optional[str] = ( os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_ACCESS_TOKEN") ) def _hf_login() -> None: """Login to HF Hub using common env secret names.""" if HF_TOKEN: try: from huggingface_hub import login, whoami login(token=HF_TOKEN, add_to_git_credential=True) try: who = whoami(token=HF_TOKEN) print(f"[HF Auth] Logged in as: {who.get('name') or who.get('fullname') or who.get('id', 'unknown')}") except Exception: print("[HF Auth] Login successful but couldn't get user info") except Exception as e: print(f"[HF Auth] Login failed: {e}") else: print("[HF Auth] No token found in environment variables") # Login before loading any models _hf_login() os.environ["TOKENIZERS_PARALLELISM"] = "false" # Load Harmony encoding if available if HARMONY_AVAILABLE: harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) else: harmony_encoding = None # Stop tokens per Harmony spec: <|return|> (200002), <|call|> (200012) HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions() if HARMONY_AVAILABLE else [] # Tokenizer is lightweight; load once try: tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN) print(f"[Model] Successfully loaded tokenizer from {MODEL_ID}") except Exception as e: print(f"[Model] Failed to load tokenizer: {e}") raise # ----------------------- # PEFT and MX Format Support # ----------------------- try: from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model _HAS_PEFT = True except Exception: _HAS_PEFT = False print("[Warning] PEFT not available. 
Install with: pip install peft") # Try to import microscaling support if available try: import msamp _HAS_MSAMP = True print("[Info] Microsoft AMP (msamp) available for MX format support") except ImportError: _HAS_MSAMP = False print("[Info] msamp not available - using fallback MX handling") # ----------------------- # MX Format Conversion # ----------------------- def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Convert fp32 LoRA weights to be compatible with MX format base model. MX models expect specific dtype handling. """ converted = {} for key, tensor in lora_state_dict.items(): if tensor is None: converted[key] = tensor continue # LoRA weights (lora_A, lora_B) need special handling if 'lora_' in key: # For MX compatibility, we keep weights in fp32 but ensure proper scaling # MX format internally handles quantization, we just need clean fp32 inputs if tensor.dtype != torch.float32: tensor = tensor.to(torch.float32) # Ensure weights are in reasonable range for MX quantization # MX format works best with weights in [-1, 1] range if 'lora_A' in key: # Input projection - initialize with small values std = 1.0 / torch.sqrt(torch.tensor(tensor.shape[1], dtype=torch.float32)) if tensor.std() > std * 10: # If weights are too large print(f"[MX Convert] Scaling down {key} from std={tensor.std():.4f} to {std:.4f}") tensor = tensor * (std / tensor.std()) elif 'lora_B' in key: # Output projection - should be near zero initially if tensor.abs().max() > 0.1: print(f"[MX Convert] Scaling down {key} max={tensor.abs().max():.4f}") tensor = tensor * 0.01 converted[key] = tensor else: # Non-LoRA weights (like embeddings) stay as-is converted[key] = tensor return converted def prepare_model_for_mx_lora(model, adapter_path: str): """ Prepare and attach LoRA adapter to MX format model. Handles the special requirements of GPT-OSS MX models. """ if not _HAS_PEFT: raise RuntimeError("PEFT is required for LoRA adapters. 
Install with: pip install peft") print(f"[LoRA] Loading adapter from {adapter_path}") # Load the LoRA config peft_config = PeftConfig.from_pretrained(adapter_path, token=HF_TOKEN) # Load the LoRA weights from safetensors.torch import load_file import os.path as osp adapter_weights_path = osp.join(adapter_path, "adapter_model.safetensors") if not osp.exists(adapter_weights_path): adapter_weights_path = osp.join(adapter_path, "adapter_model.bin") if osp.exists(adapter_weights_path): adapter_weights = torch.load(adapter_weights_path, map_location="cpu") else: raise FileNotFoundError(f"No adapter weights found at {adapter_path}") else: adapter_weights = load_file(adapter_weights_path) # Convert weights for MX compatibility print("[LoRA] Converting fp32 weights for MX format compatibility...") adapter_weights = convert_fp32_lora_to_mx_compatible(adapter_weights) # Create PEFT model with special handling for MX print("[LoRA] Attaching LoRA to base model...") # For MX models, we need to be careful about dtype # The base model uses MX format internally, but the interface should be fp32 model = PeftModel.from_pretrained( model, adapter_path, is_trainable=False, token=HF_TOKEN, # Don't specify torch_dtype here - let it match the base model ) # Manually update the adapter weights with our converted versions model.load_state_dict(adapter_weights, strict=False) print("[LoRA] Successfully attached LoRA adapter with MX compatibility") return model # ----------------------- # Model loading with MX support # ----------------------- def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]: """Build kwargs for model loading with MX format support.""" kw: Dict[str, Any] = dict( device_map=device_map, trust_remote_code=True, low_cpu_mem_usage=True, token=HF_TOKEN, ) if IS_GPT_OSS and USE_MX_FORMAT: # GPT-OSS models use MX format # Don't specify torch_dtype - let the model use its native MX format print("[Model] Using MX format for GPT-OSS model") kw.update({ "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager", # MX models handle their own dtype internally # Don't force a dtype here }) else: # Non-MX models kw.update({ "torch_dtype": torch.float16, # Use fp16 for non-MX models "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager", }) return kw def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM: """Load model with proper MX format handling.""" print(f"[Model] Loading base model from {MODEL_ID}...") # Load config first to check for MX format config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN) # Check if this is an MX model is_mx_model = ( IS_GPT_OSS or hasattr(config, 'quantization_config') and 'mx' in str(config.quantization_config).lower() or hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower() ) if is_mx_model: print("[Model] Detected MX format model - using special loading") # For MX models, we need special handling # The model internally uses MX quantization model = AutoModelForCausalLM.from_pretrained( MODEL_ID, config=config, trust_remote_code=True, device_map=device_map, low_cpu_mem_usage=True, token=HF_TOKEN, # Let the model handle its own dtype attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager", ) # Verify the model loaded correctly print(f"[Model] Model dtype: {next(model.parameters()).dtype}") print(f"[Model] Model device: {next(model.parameters()).device}") else: # Standard model loading model = AutoModelForCausalLM.from_pretrained(MODEL_ID, 
def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
    """Load the model with proper MX format handling."""
    print(f"[Model] Loading base model from {MODEL_ID}...")

    # Load config first to check for MX format
    config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)

    # Check if this is an MX model
    is_mx_model = (
        IS_GPT_OSS
        or (hasattr(config, 'quantization_config') and 'mx' in str(config.quantization_config).lower())
        or (hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower())
    )

    if is_mx_model:
        print("[Model] Detected MX format model - using special loading")
        # For MX models we need special handling: the model internally uses MX quantization.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=config,
            trust_remote_code=True,
            device_map=device_map,
            low_cpu_mem_usage=True,
            token=HF_TOKEN,
            # Let the model handle its own dtype
            attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
        )
        # Verify the model loaded correctly
        print(f"[Model] Model dtype: {next(model.parameters()).dtype}")
        print(f"[Model] Model device: {next(model.parameters()).device}")
    else:
        # Standard model loading
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))

    # Load and attach the LoRA adapter if specified
    if ADAPTER_ID:
        try:
            if is_mx_model:
                # Use special MX-compatible LoRA loading
                model = prepare_model_for_mx_lora(model, ADAPTER_ID)
            else:
                # Standard PEFT loading for non-MX models
                if not _HAS_PEFT:
                    raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
                print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
                model = PeftModel.from_pretrained(
                    model,
                    ADAPTER_ID,
                    is_trainable=False,
                    token=HF_TOKEN,
                )
            print("[Model] Successfully loaded with LoRA adapter")

            # Optionally merge the adapter for better performance
            merge_adapter = os.getenv("MERGE_ADAPTER", "0") == "1"
            if merge_adapter and hasattr(model, 'merge_and_unload'):
                print("[Model] Merging adapter into base model...")
                model = model.merge_and_unload()
                print("[Model] Adapter merged successfully")
        except Exception as e:
            print(f"[Error] Failed to load adapter: {e}")
            print("[Warning] Continuing with base model only")

    model.eval()

    # Ensure proper config
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    model.config.use_cache = True

    print(f"[Model] Model loaded successfully - Type: {'MX Format' if is_mx_model else 'Standard'}")
    return model


# -----------------------
# Harmony formatting
# -----------------------
def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
    """Build a Harmony-formatted prompt."""
    if HARMONY_AVAILABLE and harmony_encoding is not None:
        effort_map = {
            "low": ReasoningEffort.LOW,
            "medium": ReasoningEffort.MEDIUM,
            "high": ReasoningEffort.HIGH,
        }
        effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)

        system_content = (
            SystemContent.new()
            .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
            .with_reasoning_effort(effort)
            .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
            .with_knowledge_cutoff("2024-06")
            .with_required_channels(REQUIRED_CHANNELS)
        )

        sys_text = SYSTEM_DEF
        rest: List[Dict[str, str]] = messages or []
        if rest and rest[0].get("role") == "system":
            sys_text = rest[0].get("content") or SYSTEM_DEF
            rest = rest[1:]

        harmony_messages = [Message.from_role_and_content(Role.SYSTEM, system_content)]
        dev = DeveloperContent.new().with_instructions(sys_text)
        harmony_messages.append(Message.from_role_and_content(Role.DEVELOPER, dev))

        for m in rest:
            role = m.get("role")
            content = m.get("content", "")
            if role == "user":
                harmony_messages.append(Message.from_role_and_content(Role.USER, content))
            elif role == "assistant":
                harmony_messages.append(
                    Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
                )

        convo = Conversation.from_messages(harmony_messages)
        return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

    # Fallback: tokenizer chat template
    if not messages or messages[0].get("role") != "system":
        messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)


def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
    """Parse response tokens using the Harmony format to extract channels."""
    if not HARMONY_AVAILABLE:
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel_fallback(text), "raw": text}

    parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
    channels = {}
    for msg in parsed_messages:
        channel = msg.channel if hasattr(msg, 'channel') else "final"
        if channel not in channels:
            channels[channel] = ""
        channels[channel] += "".join(
            getattr(part, "text", str(part))
            for part in (msg.content if isinstance(msg.content, list) else [msg.content])
        )

    if "final" not in channels:
        channels["final"] = " ".join(channels.values())
    return channels
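# The string-level fallback below assumes the decoded completion uses Harmony channel
# markup, roughly of this shape (illustrative content, not real model output):
#   <|channel|>analysis<|message|>...private reasoning...<|end|>
#   <|channel|>final<|message|>Paris is the capital of France.<|return|>
# Only the "final" channel is surfaced to the user unless "show thinking" is enabled.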
def extract_final_channel_fallback(text: str) -> str:
    """Extract the final channel from decoded Harmony text (string-level fallback)."""
    try:
        chunks: Dict[str, str] = {}
        pieces = text.split("<|channel|>")
        for seg in pieces[1:]:
            name_end = seg.find("<|message|>")
            if name_end <= 0:
                continue
            ch = seg[:name_end].strip()
            body_start = name_end + len("<|message|>")
            next_pos = len(seg)
            for delim in ("<|channel|>", "<|end|>", "<|return|>"):
                p = seg.find(delim, body_start)
                if p != -1:
                    next_pos = min(next_pos, p)
            body = seg[body_start:next_pos]
            chunks[ch] = chunks.get(ch, "") + body

        final_txt = chunks.get("final", "").strip()
        if final_txt:
            return final_txt

        if "<|channel|>final<|message|>" in text:
            tail = text.split("<|channel|>final<|message|>")[-1]
            for delim in ("<|return|>", "<|end|>", "<|channel|>"):
                idx = tail.find(delim)
                if idx != -1:
                    tail = tail[:idx]
                    break
            return tail.strip()
    except Exception:
        pass
    return text.strip()


# -----------------------
# Rose guidance
# -----------------------
def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
    """Create a vocabulary bias vector from a {token: weight} mapping."""
    vocab_size = len(tokenizer)
    bias = torch.zeros(vocab_size, dtype=torch.float32)
    for tok, w in mapping.items():
        if tok is None:
            continue
        tid = tokenizer.convert_tokens_to_ids(tok)
        if isinstance(tid, list):
            for t in tid:
                if isinstance(t, int) and t >= 0:
                    bias[t] += float(w) / max(1, len(tid))
        elif isinstance(tid, int) and tid >= 0:
            bias[tid] += float(w)
    return bias


class RoseGuidedLogits(torch.nn.Module):
    def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
        super().__init__()
        self.bias_vec = bias_vec
        self.alpha = float(alpha)

    def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        bias = self.bias_vec.to(scores.device)
        # The model's logit dimension can differ from len(tokenizer) (padded vocab);
        # pad or truncate the bias so the addition always lines up.
        if bias.shape[0] < scores.shape[-1]:
            bias = torch.nn.functional.pad(bias, (0, scores.shape[-1] - bias.shape[0]))
        elif bias.shape[0] > scores.shape[-1]:
            bias = bias[: scores.shape[-1]]
        return scores + self.alpha * bias
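# Minimal usage sketch for the Rose bias (the token weights below are hypothetical):
#   bias = build_bias_from_tokens(tokenizer, {"rose": 1.5, "thorn": -0.5})
#   processor = RoseGuidedLogits(bias, alpha=1.0)
# At each decoding step the logits become scores + alpha * bias, so positive weights
# nudge generation toward those tokens and negative weights away from them.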
# -----------------------
# Generation
# -----------------------
@spaces.GPU(duration=120)
def zerogpu_generate(
    full_prompt,
    gen_kwargs: Dict[str, Any],
    rose_map: Optional[Dict[str, float]],
    rose_alpha: float,
    rose_score: Optional[float],
    seed: Optional[int],
) -> Dict[str, str]:
    """Run inference on GPU with MX format support."""
    try:
        if seed is not None:
            torch.manual_seed(int(seed))

        # Load the model with MX support
        model = _load_model_on("auto")

        # Set up the logits processor for Rose guidance
        logits_processor = None
        if rose_map:
            bias = build_bias_from_tokens(tokenizer, rose_map).to(next(model.parameters()).device)
            eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
            logits_processor = [RoseGuidedLogits(bias, eff_alpha)]

        # Prepare inputs
        device = next(model.parameters()).device
        if HARMONY_AVAILABLE and isinstance(full_prompt, list):
            input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
            inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
            prompt_len = input_ids.shape[1]
        else:
            enc = tokenizer(full_prompt, return_tensors="pt")
            inputs = enc.to(device)
            prompt_len = int(inputs["input_ids"].shape[1])
            if "attention_mask" not in inputs:
                inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)

        # Generate
        eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
        out_ids = model.generate(
            **inputs,
            do_sample=bool(gen_kwargs.get("do_sample", True)),
            temperature=float(gen_kwargs.get("temperature", 0.7)),
            top_p=float(gen_kwargs.get("top_p", 0.9)),
            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
            max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
            pad_token_id=model.config.pad_token_id,
            eos_token_id=eos_ids,
            logits_processor=logits_processor,
            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
            min_new_tokens=1,
        )

        # Extract generated tokens
        out_list = out_ids[0].tolist()
        gen_ids = out_list[prompt_len:]

        # Truncate at stop tokens
        if HARMONY_AVAILABLE:
            for sid in HARMONY_STOP_IDS:
                if sid in gen_ids:
                    gen_ids = gen_ids[:gen_ids.index(sid)]
                    break

        # Parse the response
        if HARMONY_AVAILABLE:
            try:
                channels = parse_harmony_response(gen_ids)
            except Exception:
                decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
                channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
        else:
            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
            channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}

        return channels

    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(f"[Error] Generation failed:\n{error_trace}")
        return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": error_trace}
    finally:
        # Cleanup
        try:
            del model
        except Exception:
            pass
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
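# zerogpu_generate returns a dict keyed by Harmony channel, e.g. (illustrative):
#   {"analysis": "...reasoning...", "final": "The average speed is 60 mph."}
# generate_response below decides whether to expose only "final" or all channels.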
# -----------------------
# Gradio handlers
# -----------------------
def generate_response(message: str,
                      history: List[Any],
                      system_prompt: str,
                      temperature: float,
                      top_p: float,
                      top_k: int,
                      max_new_tokens: int,
                      do_sample: bool,
                      seed: Optional[int],
                      rose_enable: bool,
                      rose_alpha: float,
                      rose_score: Optional[float],
                      rose_tokens: str,
                      rose_json: str,
                      show_thinking: bool = False,
                      reasoning_effort: str = "high") -> str:
    """Generate a response with CoT handling."""
    try:
        # Build messages
        messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
        if history:
            for turn in history:
                # gr.ChatInterface(type="messages") passes history as {"role", "content"}
                # dicts; older tuple-style [user, assistant] pairs are handled too.
                if isinstance(turn, dict):
                    role = turn.get("role")
                    content = turn.get("content")
                    if role in ("user", "assistant") and content:
                        messages.append({"role": role, "content": str(content)})
                elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
                    user_msg, assistant_msg = turn[0], turn[1]
                    if user_msg:
                        messages.append({"role": "user", "content": str(user_msg)})
                    if assistant_msg:
                        messages.append({"role": "assistant", "content": str(assistant_msg)})
        messages.append({"role": "user", "content": str(message)})

        # Create the prompt
        if HARMONY_AVAILABLE:
            prompt = create_harmony_prompt(messages, reasoning_effort)
        else:
            prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        # Build the Rose map
        rose_map: Optional[Dict[str, float]] = None
        if rose_enable:
            rose_map = {}
            tok_str = (rose_tokens or "").strip()
            if tok_str:
                for p in [p.strip() for p in tok_str.split(",") if p.strip()]:
                    if ":" in p:
                        k, v = p.split(":", 1)
                        try:
                            rose_map[k.strip()] = float(v)
                        except Exception:
                            pass
            if rose_json:
                try:
                    j = json.loads(rose_json)
                    if isinstance(j, dict):
                        for k, v in j.items():
                            try:
                                rose_map[str(k)] = float(v)
                            except Exception:
                                pass
                except Exception:
                    pass
            if not rose_map:
                rose_map = None

        # Generate
        channels = zerogpu_generate(
            prompt,
            {
                "do_sample": bool(do_sample),
                "temperature": float(temperature),
                "top_p": float(top_p),
                "top_k": int(top_k) if top_k > 0 else None,
                "max_new_tokens": int(max_new_tokens),
                "repetition_penalty": 1.1,
                "no_repeat_ngram_size": 6,
            },
            rose_map,
            float(rose_alpha),
            float(rose_score) if rose_score is not None else None,
            int(seed) if seed is not None else None,
        )

        # Format the response
        if show_thinking:
            response = "## Chain of Thought:\n\n"
            for channel, content in channels.items():
                if channel != "final" and content:
                    response += f"### {channel.capitalize()} Channel:\n{content}\n\n"
            response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
            return response
        else:
            return channels.get("final", "No final response generated")

    except Exception as e:
        import traceback
        return f"[Error] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"


# -----------------------
# UI
# -----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
# Mirel – Harmony Chain-of-Thought Inference

**Model**: {MODEL_ID} {'(MX Format)' if USE_MX_FORMAT else ''}
**Adapter**: {ADAPTER_ID or 'None'}
**Status**: {'✅ Harmony Available' if HARMONY_AVAILABLE else '⚠️ Harmony Not Installed'}

The model uses internal thinking channels before providing final responses.
"""
    )

    with gr.Row():
        system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_DEF, lines=2)

    with gr.Accordion("Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
            top_k = gr.Slider(0, 200, value=0, step=1, label="Top-k (0=disabled)")
        with gr.Row():
            max_new = gr.Slider(16, 4096, value=MAX_DEF, step=16, label="Max new tokens")
            do_sample = gr.Checkbox(value=True, label="Do sample")
            seed = gr.Number(value=None, label="Seed (optional)", precision=0)
        with gr.Row():
            reasoning_effort = gr.Radio(
                choices=["low", "medium", "high"],
                value="high",
                label="Reasoning Effort",
                info="How much thinking the model should do",
            )
            show_thinking = gr.Checkbox(
                value=False,
                label="Show thinking channels",
                info="Display all internal reasoning channels",
            )

    with gr.Accordion("Rose Guidance (Optional)", open=False):
        gr.Markdown("Fine-tune generation with token biases")
        with gr.Row():
            rose_enable = gr.Checkbox(value=False, label="Enable Rose bias")
            rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="Alpha (strength)")
            rose_score = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Score multiplier")
        rose_tokens = gr.Textbox(
            label="Token:weight pairs",
            placeholder="example:1.5, test:-0.5",
            value="",
        )
        rose_json = gr.Textbox(
            label="JSON weights",
            placeholder='{"token": 1.0, "another": -0.5}',
            value="",
        )

    # Chat interface
    chat = gr.ChatInterface(
        fn=generate_response,
        type="messages",
        additional_inputs=[
            system_prompt, temperature, top_p, top_k, max_new, do_sample, seed,
            rose_enable, rose_alpha, rose_score, rose_tokens, rose_json,
            show_thinking, reasoning_effort,
        ],
        title="Chat with Mirel",
        description="Chain-of-thought model with MX format support",
        examples=[
            ["Hello! Can you introduce yourself?"],
            ["What is the capital of France?"],
            ["Explain quantum computing in simple terms"],
            ["Solve: If a train travels 120 miles in 2 hours, what is its average speed?"],
        ],
        cache_examples=False,
    )

    gr.Markdown(
        """
---
### Configuration:
- **MX Format**: Automatically detected for GPT-OSS models
- **LoRA Support**: fp32 LoRA adapters are converted for MX compatibility
- **Merge Adapter**: Set `MERGE_ADAPTER=1` to merge the LoRA into the base model
- **Auth**: Set `HF_TOKEN` in Space secrets for private model access
"""
    )

if __name__ == "__main__":
    demo.queue(max_size=8 if ZEROGPU else 32).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
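# Rough dependency sketch (assumption, inferred from the imports above; pin versions
# in the Space's requirements.txt as appropriate):
#   gradio, spaces, torch, transformers, peft, safetensors, openai-harmony,
#   huggingface_hub, numpy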