""" | |
Mirel Harmony Inference β HF Space (Gradio) | |
ZeroGPU-ready, Harmony formatting, bf16 mode for GPT-OSS-20B | |
Proper LoRA adapter loading (MX format not available in stable releases) | |
Single file: app.py | |
""" | |
from __future__ import annotations

# ===== MAIN IMPORTS =====
import os, gc, json, warnings, traceback
import subprocess, sys
from dataclasses import dataclass
from typing import List, Dict, Optional, Any, Union
from datetime import datetime

import gradio as gr
import spaces  # required for ZeroGPU
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import numpy as np

# IMPORTANT: Don't import torch at module level for ZeroGPU.
# It is imported inside GPU-decorated functions instead.

# Suppress microscaling/MX-related warnings
warnings.filterwarnings("ignore", message=".*microscaling.*")
warnings.filterwarnings("ignore", message=".*mx.*")
# Import Harmony components
try:
    from openai_harmony import (
        Author,
        Conversation,
        HarmonyEncodingName,
        Message,
        Role,
        SystemContent,
        DeveloperContent,
        load_harmony_encoding,
        ReasoningEffort,
    )
    HARMONY_AVAILABLE = True
    print("✅ OpenAI Harmony loaded successfully")
except ImportError:
    print("❌ openai_harmony not installed. Install with: pip install openai-harmony")
    HARMONY_AVAILABLE = False
# Import PEFT for LoRA support
try:
    from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
    _HAS_PEFT = True
    print("✅ PEFT loaded successfully")
except Exception:
    _HAS_PEFT = False
    print("❌ PEFT not available. Install with: pip install peft")

# Note: the native MX format requires unreleased Triton features.
# We'll use bf16 mode, which works fine for inference but uses more memory.
_HAS_TRITON_KERNELS = False
USE_MX_FORMAT = False
print("Note: Using bf16 mode (MX format requires unreleased Triton features)")
print("This will work fine but use more memory than native MX format")
# ===== CONFIGURATION =====
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "1")) == "1"
MERGE_ADAPTER = os.getenv("MERGE_ADAPTER", "0") == "1"

# Detect whether a GPT-OSS model is in use
IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
USE_MX_FORMAT = IS_GPT_OSS and _HAS_TRITON_KERNELS

# Harmony channels for chain-of-thought
REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
# HF Authentication
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGING_FACE_HUB_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HF_ACCESS_TOKEN")
)
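
# Illustrative configuration for a local run or the Space's "Variables and
# secrets" panel (values mirror the defaults above; the token is a placeholder):
#   export MODEL_ID=openai/gpt-oss-20b
#   export ADAPTER_ID=AbstractPhil/mirel-gpt-oss-20b
#   export ADAPTER_SUBFOLDER=checkpoints/checkpoint-516
#   export HF_TOKEN=hf_...your-token...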
def _hf_login():
    """Log in to the Hugging Face Hub if a token is available."""
    if HF_TOKEN:
        try:
            from huggingface_hub import login, whoami
            login(token=HF_TOKEN, add_to_git_credential=True)
            try:
                user = whoami(token=HF_TOKEN)
                print(f"✅ Logged in as: {user.get('name', user.get('id', 'unknown'))}")
            except Exception:
                print("✅ HF login successful")
        except Exception as e:
            print(f"❌ HF login failed: {e}")
    else:
        print("⚠️ No HF_TOKEN found in environment")

# Login before loading models
_hf_login()

# Disable tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ===== LOAD TOKENIZER =====
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
    print(f"✅ Tokenizer loaded from {MODEL_ID}")
except Exception as e:
    print(f"❌ Failed to load tokenizer: {e}")
    raise
# ===== HARMONY SETUP =====
if HARMONY_AVAILABLE:
    harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions()
else:
    harmony_encoding = None
    HARMONY_STOP_IDS = []
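
# stop_tokens_for_assistant_actions() returns the token ids Harmony uses to end
# an assistant turn (typically the <|return|> and <|call|> special tokens); they
# double as eos_token_id values during generation below.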
# ===== MODEL LOADING WITH MX FORMAT SUPPORT =====
def detect_mx_format(model) -> bool:
    """Check whether the model is using the native MX (microscaling) format."""
    if not hasattr(model, 'model') or not hasattr(model.model, 'layers'):
        return False
    try:
        first_layer = model.model.layers[0]
        if hasattr(first_layer, 'block_sparse_moe'):
            expert = first_layer.block_sparse_moe.experts[0]
            if hasattr(expert, 'w1'):
                # MX-quantized weights carry separate scale tensors
                return hasattr(expert.w1, 'scales')
    except Exception:
        pass
    return False
def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
    """Load the base model with proper MX format handling."""
    import torch  # Import torch here for ZeroGPU compatibility

    print(f"\n{'='*50}")
    print(f"Loading model: {MODEL_ID}")
    print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
    print(f"{'='*50}\n")

    # Load config to check model type
    config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)

    # Build loading kwargs
    load_kwargs = {
        "trust_remote_code": True,
        "device_map": device_map,
        "low_cpu_mem_usage": True,
        "token": HF_TOKEN,
        "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
    }

    if IS_GPT_OSS:
        if _HAS_TRITON_KERNELS:
            print("✅ Loading with native MX format support")
            # For MX format, let the model handle its own dtype
            load_kwargs["torch_dtype"] = "auto"
            # Set environment variable to ensure MX is used
            os.environ["FORCE_MX_QUANTIZATION"] = "1"
        else:
            print("⚠️ No triton_kernels - falling back to bf16 (dequantized)")
            print("   This will likely cause LoRA compatibility issues!")
            load_kwargs["torch_dtype"] = torch.bfloat16
            # Explicitly disable MX
            os.environ["FORCE_MX_QUANTIZATION"] = "0"
    else:
        # Non-GPT-OSS models
        load_kwargs["torch_dtype"] = torch.bfloat16
    try:
        # Load the model
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

        # Verify format
        print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
        if IS_GPT_OSS:
            is_mx = detect_mx_format(model)
            if is_mx:
                print("✅ Confirmed: Using native MX format")
            else:
                print("⚠️ Model dequantized to bf16 - LoRA may fail")

        # Set model config
        if getattr(model.config, "pad_token_id", None) is None:
            model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
        model.config.use_cache = True

        return model

    except Exception as e:
        if "ragged_tma" in str(e):
            print("\n" + "="*60)
            print("ERROR: Triton version incompatibility detected!")
            print("The model requires a specific Triton version with ragged_tma support.")
            print("\nTo fix this, run:")
            print("pip uninstall -y triton triton_kernels")
            print("pip install --index-url https://download.pytorch.org/whl/nightly/cu121 triton")
            print("pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels")
            print("="*60 + "\n")

            # Try to load without MX as a fallback
            print("Attempting to load model without MX format...")
            load_kwargs["torch_dtype"] = torch.bfloat16
            os.environ["FORCE_MX_QUANTIZATION"] = "0"
            model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
            print("✅ Model loaded in bf16 mode (degraded performance)")
            return model
        else:
            raise
def load_lora_adapter(model, adapter_id: str, subfolder: Optional[str] = None):
    """Load and attach a LoRA adapter to the bf16 model."""
    if not _HAS_PEFT:
        raise RuntimeError("PEFT is required for LoRA adapters")

    print(f"\n{'='*50}")
    print(f"Loading LoRA: {adapter_id}")
    if subfolder:
        print(f"Subfolder: {subfolder}")
    print(f"{'='*50}\n")

    # Prepare kwargs for PEFT
    peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
    if subfolder:
        peft_kwargs["subfolder"] = subfolder

    try:
        # Load adapter configuration
        peft_config = PeftConfig.from_pretrained(adapter_id, **peft_kwargs)
        print(f"LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")

        # Load the adapter
        model = PeftModel.from_pretrained(model, adapter_id, **peft_kwargs)

        # Warning about potential format mismatch
        if IS_GPT_OSS:
            print("⚠️ WARNING: LoRA may have been trained on MX format")
            print("   Model is running in bf16 mode - there may be compatibility issues")
            print("   If generation quality is poor, the LoRA may need retraining on bf16")
        print("✅ LoRA adapter loaded")

        # Optionally merge the adapter into the base weights
        if MERGE_ADAPTER and hasattr(model, 'merge_and_unload'):
            print("Merging adapter into base model...")
            model = model.merge_and_unload()
            print("✅ Adapter merged")

        return model

    except Exception as e:
        print(f"❌ Failed to load LoRA: {e}")
        print("Continuing with base model only")
        return model
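
# The adapter repo (or ADAPTER_SUBFOLDER within it) is expected to hold the
# standard PEFT files, i.e. adapter_config.json plus adapter_model.safetensors
# (or adapter_model.bin); if loading fails, we fall through to the base model.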
# ===== HARMONY FORMATTING =====
def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high"):
    """Create a Harmony-formatted prompt (token ids), or fall back to the chat template."""
    if not HARMONY_AVAILABLE or not harmony_encoding:
        # Fallback to the tokenizer's chat template (returns a string)
        if messages and messages[0].get("role") != "system":
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # Map reasoning effort
    effort_map = {
        "low": ReasoningEffort.LOW,
        "medium": ReasoningEffort.MEDIUM,
        "high": ReasoningEffort.HIGH,
    }
    effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)

    # Build the Harmony system message
    system_content = (
        SystemContent.new()
        .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
        .with_reasoning_effort(effort)
        .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
        .with_knowledge_cutoff("2024-06")
        .with_required_channels(REQUIRED_CHANNELS)
    )

    # Extract the system prompt (used as developer instructions)
    sys_text = SYSTEM_PROMPT
    rest = messages or []
    if rest and rest[0].get("role") == "system":
        sys_text = rest[0].get("content", SYSTEM_PROMPT)
        rest = rest[1:]

    # Build messages
    harmony_messages = [
        Message.from_role_and_content(Role.SYSTEM, system_content),
        Message.from_role_and_content(
            Role.DEVELOPER,
            DeveloperContent.new().with_instructions(sys_text),
        ),
    ]
    for msg in rest:
        role = msg.get("role")
        content = msg.get("content", "")
        if role == "user":
            harmony_messages.append(Message.from_role_and_content(Role.USER, content))
        elif role == "assistant":
            harmony_messages.append(
                Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
            )

    # Render to token ids for completion
    convo = Conversation.from_messages(harmony_messages)
    return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
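
# For orientation only: the rendered prefill corresponds roughly to the token
# ids of a string shaped like (abridged, illustrative):
#   <|start|>system<|message|>...identity, reasoning effort, channels...<|end|>
#   <|start|>developer<|message|># Instructions\n<system prompt><|end|>
#   <|start|>user<|message|><latest user turn><|end|><|start|>assistant
# so the model continues in its analysis/commentary/final channels.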
def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
    """Parse Harmony response tokens into per-channel text."""
    if not HARMONY_AVAILABLE or not harmony_encoding:
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel(text), "raw": text}

    try:
        # Parse using Harmony
        parsed = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
        channels = {}
        for msg in parsed:
            channel = getattr(msg, 'channel', 'final') or 'final'
            if channel not in channels:
                channels[channel] = ""
            # Extract text content
            content = msg.content
            if isinstance(content, list):
                text = "".join(getattr(part, "text", str(part)) for part in content)
            else:
                text = getattr(content, "text", str(content))
            channels[channel] += text

        # Ensure a final channel exists
        if "final" not in channels:
            channels["final"] = " ".join(channels.values())
        return channels

    except Exception as e:
        print(f"Harmony parsing failed: {e}")
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel(text), "raw": text}
def extract_final_channel(text: str) -> str:
    """Extract the final channel from raw decoded text."""
    # Look for <|channel|>final<|message|>
    if "<|channel|>final<|message|>" in text:
        parts = text.split("<|channel|>final<|message|>")
        if len(parts) > 1:
            final = parts[-1]
            # Truncate at the next special marker
            for marker in ["<|channel|>", "<|end|>", "<|return|>"]:
                if marker in final:
                    final = final.split(marker)[0]
            return final.strip()

    # Fallback: return the text with markers stripped
    for marker in ["<|channel|>", "<|message|>", "<|end|>", "<|return|>"]:
        text = text.replace(marker, " ")
    return text.strip()
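
# Illustrative behavior on a hypothetical raw completion:
#   extract_final_channel(
#       "<|channel|>analysis<|message|>thinking...<|end|>"
#       "<|channel|>final<|message|>Paris is the capital of France.<|return|>"
#   )
#   -> "Paris is the capital of France."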
# ===== GENERATION =====
@spaces.GPU(duration=120)  # ZeroGPU: request a GPU for this call (120 s budget is an assumption)
def generate_on_gpu(
    prompt,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    do_sample: bool,
    repetition_penalty: float,
    seed: Optional[int],
) -> Dict[str, str]:
    """Run generation on the GPU (model is loaded per call for ZeroGPU)."""
    import torch  # Import torch inside the GPU-decorated function for ZeroGPU

    try:
        # Set seed if provided
        if seed is not None:
            torch.manual_seed(int(seed))

        # Load model
        print("\nLoading model for generation...")
        model = load_base_model("auto")

        # Load LoRA if specified
        if ADAPTER_ID:
            model = load_lora_adapter(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
        model.eval()

        # Prepare inputs
        device = next(model.parameters()).device
        if HARMONY_AVAILABLE and isinstance(prompt, list):
            # Harmony rendering returns token ids
            input_ids = torch.tensor([prompt], dtype=torch.long, device=device)
        else:
            # String prompt from the chat-template fallback
            inputs = tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)
        attention_mask = torch.ones_like(input_ids)
        prompt_len = input_ids.shape[1]

        # Generate
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k) if top_k and top_k > 0 else None,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                pad_token_id=model.config.pad_token_id,
                eos_token_id=HARMONY_STOP_IDS if HARMONY_STOP_IDS else tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
            )

        # Extract the newly generated tokens
        gen_tokens = outputs[0][prompt_len:].tolist()

        # Truncate at the first stop token
        for stop_id in HARMONY_STOP_IDS:
            if stop_id in gen_tokens:
                gen_tokens = gen_tokens[:gen_tokens.index(stop_id)]
                break

        # Parse the response into channels
        channels = parse_harmony_response(gen_tokens)
        return channels

    except Exception as e:
        error_msg = f"Generation failed: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {"final": f"Error: {str(e)}", "raw": error_msg}

    finally:
        # Free GPU memory between calls
        if 'model' in locals():
            del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
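
# Quick smoke test (illustrative; assumes the weights download and fit in GPU
# memory). Passing a plain string exercises the chat-template fallback path:
#   channels = generate_on_gpu("Hello, Mirel!", 0.7, 0.9, 50, 64, True, 1.1, None)
#   print(channels.get("final"))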
# ===== GRADIO INTERFACE =====
def chat_response(
    message: str,
    history: List[List[str]],
    system_prompt: str,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    do_sample: bool,
    repetition_penalty: float,
    seed: Optional[int],
    reasoning_effort: str,
    show_thinking: bool,
) -> str:
    """Handle a single chat turn."""
    try:
        # Build the conversation
        messages = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]

        # Add history: tuple-style [user, assistant] pairs, or messages-format
        # dicts from newer Gradio versions
        for turn in history or []:
            if isinstance(turn, (list, tuple)) and len(turn) >= 2:
                user_msg, assistant_msg = turn[0], turn[1]
                if user_msg:
                    messages.append({"role": "user", "content": str(user_msg)})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": str(assistant_msg)})
            elif isinstance(turn, dict) and turn.get("role") in ("user", "assistant"):
                messages.append({"role": turn["role"], "content": str(turn.get("content", ""))})

        # Add the current message
        messages.append({"role": "user", "content": message})

        # Create the prompt
        prompt = create_harmony_prompt(messages, reasoning_effort)

        # Generate
        channels = generate_on_gpu(
            prompt,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            do_sample,
            repetition_penalty,
            seed,
        )

        # Format the response
        if show_thinking and len(channels) > 1:
            response = "## Chain of Thought:\n\n"
            for channel, content in channels.items():
                if channel != "final" and content:
                    response += f"### {channel.capitalize()}:\n{content}\n\n"
            response += f"### Final Response:\n{channels.get('final', 'No response generated')}"
        else:
            response = channels.get("final", "No response generated")

        return response

    except Exception as e:
        return f"Error: {str(e)}"
# ===== BUILD UI =====
with gr.Blocks(theme=gr.themes.Soft(), title="Mirel") as demo:
    # Header with status
    status_mx = "✅ MX Format" if _HAS_TRITON_KERNELS else "⚠️ No MX Support"
    status_harmony = "✅ Harmony" if HARMONY_AVAILABLE else "❌ No Harmony"
    gr.Markdown(f"""
    # 🤖 Mirel - Chain-of-Thought Assistant
    **Model:** `{MODEL_ID}` | **Adapter:** `{ADAPTER_ID or 'None'}`
    **Status:** {status_mx} | {status_harmony} | {"✅ ZeroGPU" if ZEROGPU else "CPU Mode"}
    {'''
    ⚠️ **WARNING: MX Format Support Missing!**
    Install with: `pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels`
    ''' if IS_GPT_OSS and not _HAS_TRITON_KERNELS else ''}
    """)
    # System prompt
    system_prompt = gr.Textbox(
        label="System Prompt",
        value=SYSTEM_PROMPT,
        lines=2,
    )

    # Settings
    with gr.Accordion("⚙️ Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
            top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
        with gr.Row():
            max_new_tokens = gr.Slider(16, 2048, value=MAX_NEW_TOKENS, step=16, label="Max tokens")
            repetition_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="Repetition penalty")
            seed = gr.Number(value=None, label="Seed (optional)", precision=0)
        with gr.Row():
            do_sample = gr.Checkbox(value=True, label="Sample")
            show_thinking = gr.Checkbox(value=False, label="Show thinking channels")
            reasoning_effort = gr.Radio(
                ["low", "medium", "high"],
                value="high",
                label="Reasoning effort",
            )
    # Chat interface
    chat = gr.ChatInterface(
        fn=chat_response,
        additional_inputs=[
            system_prompt,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            do_sample,
            repetition_penalty,
            seed,
            reasoning_effort,
            show_thinking,
        ],
        title=None,
        examples=[
            ["Hello! Can you introduce yourself?"],
            ["What's the capital of France?"],
            ["Explain quantum computing simply"],
            ["Write a haiku about coding"],
        ],
        cache_examples=False,
    )
    # Footer
    gr.Markdown("""
    ---
    💡 **Tips:**
    - Enable "Show thinking channels" to see the model's reasoning process
    - Adjust "Reasoning effort" for faster responses (low) or better quality (high)
    - Native MX format is used only when triton_kernels is available; otherwise the model runs in bf16
    """)
# ===== LAUNCH =====
if __name__ == "__main__":
    print("\n" + "="*60)
    print("MIREL READY TO LAUNCH")
    print(f"Model: {MODEL_ID}")
    print(f"Adapter: {ADAPTER_ID or 'None'}")
    print(f"MX Format: {'ENABLED' if _HAS_TRITON_KERNELS else 'DISABLED'}")
    print(f"Harmony: {'ENABLED' if HARMONY_AVAILABLE else 'DISABLED'}")
    print("="*60 + "\n")

    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )