| """ | |
| Mirel Harmony Inference β HF Space (Gradio) | |
| ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B | |
| Proper LoRA adapter loading and conversion for MX compatibility | |
| Single file: app.py | |
| Requirements: | |
| huggingface_hub>=0.34.0 | |
| transformers>=4.55.0 | |
| accelerate>=0.33.0 | |
| peft>=0.11.0 | |
| torch>=2.4.0 | |
| bitsandbytes>=0.43.1 | |
| openai-harmony | |
| gradio>=5.42.0 | |
| triton>=3.4.0 | |
| git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels | |
| """ | |
from __future__ import annotations  # must immediately follow the module docstring

# ===== SETUP: Ensure triton_kernels is installed for MX format =====
import subprocess
import sys


def ensure_triton_kernels():
    """Ensure triton_kernels is installed for MX format support on H200."""
    try:
        import triton_kernels
        print("✓ triton_kernels already installed - MX format supported")
        return True
    except ImportError:
        print("Installing triton_kernels for MX format support...")
        try:
            subprocess.check_call([
                sys.executable, "-m", "pip", "install",
                "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
            ])
            print("✓ triton_kernels installed successfully")
            # Refresh site so the freshly installed package is importable in this process
            import importlib
            import site
            importlib.reload(site)
            return True
        except subprocess.CalledProcessError as e:
            print(f"✗ Failed to install triton_kernels: {e}")
            print("ERROR: MX format will NOT work properly without triton_kernels!")
            return False


# Install triton_kernels before other imports
_TRITON_INSTALL_SUCCESS = ensure_triton_kernels()
# ===== MAIN IMPORTS =====
import os, gc, json, torch, warnings, traceback
from dataclasses import dataclass
from typing import List, Dict, Optional, Any, Union
from datetime import datetime

import gradio as gr
import spaces  # required for ZeroGPU
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import numpy as np
# Suppress warnings
warnings.filterwarnings("ignore", message=".*microscaling.*")
warnings.filterwarnings("ignore", message=".*mx.*")

# Import Harmony components
try:
    from openai_harmony import (
        Author,
        Conversation,
        HarmonyEncodingName,
        Message,
        Role,
        SystemContent,
        DeveloperContent,
        load_harmony_encoding,
        ReasoningEffort
    )
    HARMONY_AVAILABLE = True
    print("✓ OpenAI Harmony loaded successfully")
except ImportError:
    print("✗ openai_harmony not installed. Install with: pip install openai-harmony")
    HARMONY_AVAILABLE = False

# Import PEFT for LoRA support
try:
    from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
    _HAS_PEFT = True
    print("✓ PEFT loaded successfully")
except Exception:
    _HAS_PEFT = False
    print("✗ PEFT not available. Install with: pip install peft")

# Check for triton_kernels (required for MX format)
try:
    import triton_kernels
    _HAS_TRITON_KERNELS = True
    print("✓ triton_kernels loaded - MX format enabled")
except ImportError:
    _HAS_TRITON_KERNELS = False
    print("✗ triton_kernels not available - MX format disabled!")
# ===== CONFIGURATION =====
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "1")) == "1"
MERGE_ADAPTER = os.getenv("MERGE_ADAPTER", "0") == "1"

# Detect whether a GPT-OSS model is in use
IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
USE_MX_FORMAT = IS_GPT_OSS and _HAS_TRITON_KERNELS

# Harmony channels for chain-of-thought
REQUIRED_CHANNELS = ["analysis", "commentary", "final"]

# HF Authentication
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGING_FACE_HUB_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HF_ACCESS_TOKEN")
)
def _hf_login():
    """Log in to the Hugging Face Hub if a token is available."""
    if HF_TOKEN:
        try:
            from huggingface_hub import login, whoami
            login(token=HF_TOKEN, add_to_git_credential=True)
            try:
                user = whoami(token=HF_TOKEN)
                print(f"✓ Logged in as: {user.get('name', user.get('id', 'unknown'))}")
            except Exception:
                print("✓ HF login successful")
        except Exception as e:
            print(f"✗ HF login failed: {e}")
    else:
        print("⚠ No HF_TOKEN found in environment")


# Login before loading models
_hf_login()

# Disable tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ===== LOAD TOKENIZER =====
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
    print(f"✓ Tokenizer loaded from {MODEL_ID}")
except Exception as e:
    print(f"✗ Failed to load tokenizer: {e}")
    raise

# ===== HARMONY SETUP =====
if HARMONY_AVAILABLE:
    harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions()
else:
    harmony_encoding = None
    HARMONY_STOP_IDS = []
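# HARMONY_STOP_IDS holds the token ids that terminate an assistant action in the Harmony
# format (e.g. the <|return|> / <|call|> markers). They are passed to generate() as
# eos_token_id and used below to truncate the sampled continuation.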
# ===== MODEL LOADING WITH MX FORMAT SUPPORT =====
def detect_mx_format(model) -> bool:
    """Check whether the model is using the native MX (microscaling) format."""
    if not hasattr(model, 'model') or not hasattr(model.model, 'layers'):
        return False
    try:
        first_layer = model.model.layers[0]
        if hasattr(first_layer, 'block_sparse_moe'):
            expert = first_layer.block_sparse_moe.experts[0]
            if hasattr(expert, 'w1'):
                # MX-quantized expert weights carry separate scale tensors
                return hasattr(expert.w1, 'scales')
    except Exception:
        pass
    return False
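# Illustrative usage (assumes a GPT-OSS checkpoint has already been loaded):
#   model = load_base_model(device_map="cpu")
#   print(detect_mx_format(model))  # True only when expert weights expose MX scale tensors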
def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
    """Load the base model with proper MX format handling."""
    print(f"\n{'='*50}")
    print(f"Loading model: {MODEL_ID}")
    print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
    print(f"{'='*50}\n")

    # Load config to check model type
    config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)

    # Build loading kwargs
    load_kwargs = {
        "trust_remote_code": True,
        "device_map": device_map,
        "low_cpu_mem_usage": True,
        "token": HF_TOKEN,
        "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
    }

    if IS_GPT_OSS:
        if _HAS_TRITON_KERNELS:
            print("✓ Loading with native MX format support")
            load_kwargs["torch_dtype"] = "auto"  # let the model keep its native MX weights
        else:
            print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
            print("  This will likely cause LoRA compatibility issues!")
            load_kwargs["torch_dtype"] = torch.bfloat16
    else:
        # Non-GPT-OSS models
        load_kwargs["torch_dtype"] = torch.bfloat16

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

    # Verify format
    print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
    if IS_GPT_OSS:
        is_mx = detect_mx_format(model)
        if is_mx:
            print("✓ Confirmed: Using native MX format")
        else:
            print("⚠ Model dequantized to bf16 - LoRA may fail")

    # Set model config
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    model.config.use_cache = True

    return model
def load_lora_adapter(model, adapter_id: str, subfolder: Optional[str] = None):
    """Load and attach a LoRA adapter with MX format handling."""
    if not _HAS_PEFT:
        raise RuntimeError("PEFT is required for LoRA adapters")

    print(f"\n{'='*50}")
    print(f"Loading LoRA: {adapter_id}")
    if subfolder:
        print(f"Subfolder: {subfolder}")
    print(f"{'='*50}\n")

    # Check whether the model is using MX format
    is_mx = detect_mx_format(model) if IS_GPT_OSS else False

    # Prepare kwargs for PEFT
    peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
    if subfolder:
        peft_kwargs["subfolder"] = subfolder

    try:
        # Load adapter configuration
        peft_config = PeftConfig.from_pretrained(adapter_id, **peft_kwargs)
        print(f"LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")

        # Load the adapter
        model = PeftModel.from_pretrained(model, adapter_id, **peft_kwargs)

        if not is_mx and IS_GPT_OSS:
            print("⚠ WARNING: Model is bf16 but the LoRA was likely trained on MX format")
            print("  Reducing LoRA influence to 10% to prevent corruption")
            # Scale down LoRA weights
            for name, param in model.named_parameters():
                if 'lora_' in name:
                    param.data *= 0.1

        print("✓ LoRA adapter loaded successfully")

        # Optionally merge the adapter
        if MERGE_ADAPTER and hasattr(model, 'merge_and_unload'):
            print("Merging adapter into base model...")
            model = model.merge_and_unload()
            print("✓ Adapter merged")

        return model

    except Exception as e:
        print(f"✗ Failed to load LoRA: {e}")
        print("Continuing with base model only")
        return model
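# Note: merge_and_unload() folds the LoRA deltas into the base weights and returns a plain
# transformers model, avoiding PEFT overhead at inference time but making the adapter
# impossible to detach afterwards; it is gated behind MERGE_ADAPTER=1 above.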
# ===== HARMONY FORMATTING =====
def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high"):
    """Create a Harmony-formatted prompt."""
    if not HARMONY_AVAILABLE or not harmony_encoding:
        # Fallback to the tokenizer's chat template
        if messages and messages[0].get("role") != "system":
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # Map reasoning effort
    effort_map = {
        "low": ReasoningEffort.LOW,
        "medium": ReasoningEffort.MEDIUM,
        "high": ReasoningEffort.HIGH
    }
    effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)

    # Build Harmony conversation
    system_content = (
        SystemContent.new()
        .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
        .with_reasoning_effort(effort)
        .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
        .with_knowledge_cutoff("2024-06")
        .with_required_channels(REQUIRED_CHANNELS)
    )

    # Extract system prompt
    sys_text = SYSTEM_PROMPT
    rest = messages or []
    if rest and rest[0].get("role") == "system":
        sys_text = rest[0].get("content", SYSTEM_PROMPT)
        rest = rest[1:]

    # Build messages
    harmony_messages = [
        Message.from_role_and_content(Role.SYSTEM, system_content),
        Message.from_role_and_content(
            Role.DEVELOPER,
            DeveloperContent.new().with_instructions(sys_text)
        )
    ]
    for msg in rest:
        role = msg.get("role")
        content = msg.get("content", "")
        if role == "user":
            harmony_messages.append(Message.from_role_and_content(Role.USER, content))
        elif role == "assistant":
            harmony_messages.append(
                Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
            )

    # Render to token IDs
    convo = Conversation.from_messages(harmony_messages)
    return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
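# create_harmony_prompt() therefore returns either a list of token ids (Harmony path,
# rendered up to where the assistant should continue) or a plain string (chat-template
# fallback); generate_on_gpu() below branches on that type.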
def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
    """Parse Harmony response tokens into channels."""
    if not HARMONY_AVAILABLE or not harmony_encoding:
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel(text), "raw": text}

    try:
        # Parse using Harmony
        parsed = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
        channels = {}
        for msg in parsed:
            channel = getattr(msg, 'channel', None) or 'final'
            if channel not in channels:
                channels[channel] = ""
            # Extract text content
            content = msg.content
            if isinstance(content, list):
                text = "".join([getattr(part, "text", str(part)) for part in content])
            else:
                text = getattr(content, "text", str(content))
            channels[channel] += text

        # Ensure the final channel exists
        if "final" not in channels:
            channels["final"] = " ".join(channels.values())
        return channels

    except Exception as e:
        print(f"Harmony parsing failed: {e}")
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel(text), "raw": text}
def extract_final_channel(text: str) -> str:
    """Extract the final channel from raw Harmony-formatted text."""
    # Look for <|channel|>final<|message|>
    if "<|channel|>final<|message|>" in text:
        parts = text.split("<|channel|>final<|message|>")
        if len(parts) > 1:
            final = parts[-1]
            # Truncate at the next marker
            for marker in ["<|channel|>", "<|end|>", "<|return|>"]:
                if marker in final:
                    final = final.split(marker)[0]
            return final.strip()

    # Fallback: return cleaned text
    for marker in ["<|channel|>", "<|message|>", "<|end|>", "<|return|>"]:
        text = text.replace(marker, " ")
    return text.strip()
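# Illustrative example (hypothetical raw completion text):
#   extract_final_channel("<|channel|>analysis<|message|>reasoning...<|end|>"
#                         "<|channel|>final<|message|>Paris<|return|>")
# returns "Paris"; the analysis text before the final marker is discarded.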
# ===== GENERATION =====
@spaces.GPU(duration=120)  # ZeroGPU: request a GPU allocation for the duration of this call
def generate_on_gpu(
    prompt,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    do_sample: bool,
    repetition_penalty: float,
    seed: Optional[int]
) -> Dict[str, str]:
    """Run generation on the GPU."""
    try:
        # Set seed if provided
        if seed is not None:
            torch.manual_seed(int(seed))

        # Load model
        print("\nLoading model for generation...")
        model = load_base_model("auto")

        # Load LoRA if specified
        if ADAPTER_ID:
            model = load_lora_adapter(model, ADAPTER_ID, ADAPTER_SUBFOLDER)

        model.eval()

        # Prepare inputs
        device = next(model.parameters()).device
        if HARMONY_AVAILABLE and isinstance(prompt, list):
            # Harmony rendering already produced token IDs
            input_ids = torch.tensor([prompt], dtype=torch.long, device=device)
        else:
            # String prompt
            inputs = tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)
        attention_mask = torch.ones_like(input_ids)
        prompt_len = input_ids.shape[1]

        # Generate
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k if top_k > 0 else None,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                pad_token_id=model.config.pad_token_id,
                eos_token_id=HARMONY_STOP_IDS if HARMONY_STOP_IDS else tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
            )

        # Extract generated tokens
        gen_tokens = outputs[0][prompt_len:].tolist()

        # Truncate at stop tokens
        for stop_id in HARMONY_STOP_IDS:
            if stop_id in gen_tokens:
                gen_tokens = gen_tokens[:gen_tokens.index(stop_id)]
                break

        # Parse response
        channels = parse_harmony_response(gen_tokens)
        return channels

    except Exception as e:
        error_msg = f"Generation failed: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {"final": f"Error: {str(e)}", "raw": error_msg}

    finally:
        # Cleanup
        if 'model' in locals():
            del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
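# The model is (re)loaded inside every call and deleted in the finally block so GPU memory
# is released between ZeroGPU allocations; the worker stays stateless at the cost of
# per-request load time.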
# ===== GRADIO INTERFACE =====
def chat_response(
    message: str,
    history: List[Any],
    system_prompt: str,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    do_sample: bool,
    repetition_penalty: float,
    seed: Optional[int],
    reasoning_effort: str,
    show_thinking: bool
) -> str:
    """Handle a chat interaction."""
    try:
        # Build conversation
        messages = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]

        # Add history (supports both tuple pairs and Gradio "messages"-style dicts)
        for turn in history or []:
            if isinstance(turn, dict):
                role = turn.get("role")
                content = turn.get("content")
                if role in ("user", "assistant") and content:
                    messages.append({"role": role, "content": str(content)})
            elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
                user_msg, assistant_msg = turn[0], turn[1]
                if user_msg:
                    messages.append({"role": "user", "content": str(user_msg)})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": str(assistant_msg)})

        # Add the current message
        messages.append({"role": "user", "content": message})

        # Create prompt
        prompt = create_harmony_prompt(messages, reasoning_effort)

        # Generate
        channels = generate_on_gpu(
            prompt,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            do_sample,
            repetition_penalty,
            seed
        )

        # Format response
        if show_thinking and len(channels) > 1:
            response = "## Chain of Thought:\n\n"
            for channel, content in channels.items():
                if channel != "final" and content:
                    response += f"### {channel.capitalize()}:\n{content}\n\n"
            response += f"### Final Response:\n{channels.get('final', 'No response generated')}"
        else:
            response = channels.get("final", "No response generated")

        return response

    except Exception as e:
        return f"Error: {str(e)}"
# ===== BUILD UI =====
with gr.Blocks(theme=gr.themes.Soft(), title="Mirel") as demo:
    # Header with status
    status_mx = "✓ MX Format" if _HAS_TRITON_KERNELS else "✗ No MX Support"
    status_harmony = "✓ Harmony" if HARMONY_AVAILABLE else "✗ No Harmony"

    gr.Markdown(f"""
    # 🤖 Mirel – Chain-of-Thought Assistant
    **Model:** `{MODEL_ID}` | **Adapter:** `{ADAPTER_ID or 'None'}`
    **Status:** {status_mx} | {status_harmony} | {"✓ ZeroGPU" if ZEROGPU else "CPU Mode"}
    {'''
    ⚠️ **WARNING: MX Format Support Missing!**
    Install with: `pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels`
    ''' if IS_GPT_OSS and not _HAS_TRITON_KERNELS else ''}
    """)

    # System prompt
    system_prompt = gr.Textbox(
        label="System Prompt",
        value=SYSTEM_PROMPT,
        lines=2
    )

    # Settings
    with gr.Accordion("⚙️ Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
            top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
        with gr.Row():
            max_new_tokens = gr.Slider(16, 2048, value=MAX_NEW_TOKENS, step=16, label="Max tokens")
            repetition_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="Repetition penalty")
            seed = gr.Number(value=None, label="Seed (optional)", precision=0)
        with gr.Row():
            do_sample = gr.Checkbox(value=True, label="Sample")
            show_thinking = gr.Checkbox(value=False, label="Show thinking channels")
            reasoning_effort = gr.Radio(
                ["low", "medium", "high"],
                value="high",
                label="Reasoning effort"
            )

    # Chat interface
    chat = gr.ChatInterface(
        fn=chat_response,
        additional_inputs=[
            system_prompt,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            do_sample,
            repetition_penalty,
            seed,
            reasoning_effort,
            show_thinking
        ],
        title=None,
        examples=[
            ["Hello! Can you introduce yourself?"],
            ["What's the capital of France?"],
            ["Explain quantum computing simply"],
            ["Write a haiku about coding"],
        ],
        cache_examples=False,
    )

    # Footer
    gr.Markdown("""
    ---
    💡 **Tips:**
    - Enable "Show thinking channels" to see the model's reasoning process
    - Adjust "Reasoning effort" for faster responses (low) or better quality (high)
    - The model uses MX format on H200 GPUs for optimal performance
    """)
# ===== LAUNCH =====
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("MIREL READY TO LAUNCH")
    print(f"Model: {MODEL_ID}")
    print(f"Adapter: {ADAPTER_ID or 'None'}")
    print(f"MX Format: {'ENABLED' if _HAS_TRITON_KERNELS else 'DISABLED'}")
    print(f"Harmony: {'ENABLED' if HARMONY_AVAILABLE else 'DISABLED'}")
    print("=" * 60 + "\n")

    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )