"""
Mirel Harmony Inference – HF Space (Gradio)
ZeroGPU-ready, Harmony formatting, bf16 mode for GPT-OSS-20B
Proper LoRA adapter loading (MX format not available in stable releases)
Single file: app.py
"""
from __future__ import annotations

# ===== MAIN IMPORTS =====
import os, gc, warnings, traceback
from typing import List, Dict, Optional
from datetime import datetime
import gradio as gr
import spaces  # required for ZeroGPU
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# IMPORTANT: Don't import torch at module level for ZeroGPU
# It will be imported inside GPU-decorated functions

# Suppress warnings
warnings.filterwarnings("ignore", message=".*microscaling.*")
warnings.filterwarnings("ignore", message=".*mx.*")

# Import Harmony components
try:
    from openai_harmony import (
        Author,
        Conversation,
        HarmonyEncodingName,
        Message,
        Role,
        SystemContent,
        DeveloperContent,
        load_harmony_encoding,
        ReasoningEffort
    )
    HARMONY_AVAILABLE = True
    print("βœ“ OpenAI Harmony loaded successfully")
except ImportError:
    print("⚠ openai_harmony not installed. Install with: pip install openai-harmony")
    HARMONY_AVAILABLE = False

# Import PEFT for LoRA support
try:
    from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
    _HAS_PEFT = True
    print("βœ“ PEFT loaded successfully")
except Exception:
    _HAS_PEFT = False
    print("⚠ PEFT not available. Install with: pip install peft")

# Note: MX format requires unreleased Triton features
# We'll use bf16 mode which works fine for inference
_HAS_TRITON_KERNELS = False
USE_MX_FORMAT = False

print("Note: Using bf16 mode (MX format requires unreleased Triton features)")
print("This will work fine but use more memory than native MX format")

# ===== CONFIGURATION =====
MODEL_ID          = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
ADAPTER_ID        = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")
ATTN_IMPL         = os.getenv("ATTN_IMPL", "eager")
SYSTEM_PROMPT     = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
MAX_NEW_TOKENS    = int(os.getenv("MAX_NEW_TOKENS", "512"))
ZEROGPU           = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "1")) == "1"
MERGE_ADAPTER     = os.getenv("MERGE_ADAPTER", "0") == "1"
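
# Example Space configuration (illustrative; the values shown are just the defaults above):
#   MODEL_ID=openai/gpt-oss-20b
#   ADAPTER_ID=AbstractPhil/mirel-gpt-oss-20b
#   ADAPTER_SUBFOLDER=checkpoints/checkpoint-516
#   MERGE_ADAPTER=1   # set to merge the LoRA into the base weights after loading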

# Detect if using GPT-OSS model
IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
USE_MX_FORMAT = IS_GPT_OSS and _HAS_TRITON_KERNELS

# Harmony channels for chain-of-thought
REQUIRED_CHANNELS = ["analysis", "commentary", "final"]

# HF Authentication
HF_TOKEN = (
    os.getenv("HF_TOKEN") 
    or os.getenv("HUGGING_FACE_HUB_TOKEN") 
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HF_ACCESS_TOKEN")
)
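# The first non-empty variable in the chain above wins.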

def _hf_login():
    """Login to HuggingFace Hub."""
    if HF_TOKEN:
        try:
            from huggingface_hub import login, whoami
            login(token=HF_TOKEN, add_to_git_credential=True)
            try:
                user = whoami(token=HF_TOKEN)
                print(f"βœ“ Logged in as: {user.get('name', user.get('id', 'unknown'))}")
            except:
                print("βœ“ HF login successful")
        except Exception as e:
            print(f"⚠ HF login failed: {e}")
    else:
        print("⚠ No HF_TOKEN found in environment")

# Login before loading models
_hf_login()

# Disable tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ===== LOAD TOKENIZER =====
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
    print(f"βœ“ Tokenizer loaded from {MODEL_ID}")
except Exception as e:
    print(f"βœ— Failed to load tokenizer: {e}")
    raise

# ===== HARMONY SETUP =====
if HARMONY_AVAILABLE:
    harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions()
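    # These ids cover the Harmony tokens that end an assistant action (e.g. <|return|>
    # and <|call|>), so generation stops cleanly at the end of a turn.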
else:
    harmony_encoding = None
    HARMONY_STOP_IDS = []

# ===== MODEL LOADING WITH MX FORMAT SUPPORT =====

def detect_mx_format(model) -> bool:
    """Check if model is using native MX format."""
    if not hasattr(model, 'model') or not hasattr(model.model, 'layers'):
        return False
    
    try:
        first_layer = model.model.layers[0]
        if hasattr(first_layer, 'block_sparse_moe'):
            expert = first_layer.block_sparse_moe.experts[0]
            if hasattr(expert, 'w1'):
                # Check for MX format scale tensors
                return hasattr(expert.w1, 'scales')
    except Exception:
        pass
    return False

def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
    """Load the base model with proper MX format handling."""
    import torch  # Import torch here for ZeroGPU compatibility
    
    print(f"\n{'='*50}")
    print(f"Loading model: {MODEL_ID}")
    print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
    print(f"{'='*50}\n")
    
    # Load config to check model type
    config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
    
    # Build loading kwargs
    load_kwargs = {
        "trust_remote_code": True,
        "device_map": device_map,
        "low_cpu_mem_usage": True,
        "token": HF_TOKEN,
        "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
    }
    
    if IS_GPT_OSS:
        if _HAS_TRITON_KERNELS:
            print("→ Loading with native MX format support")
            # For MX format, let the model handle its own dtype
            load_kwargs["torch_dtype"] = "auto"

            # Set environment variable to ensure MX is used
            os.environ["FORCE_MX_QUANTIZATION"] = "1"
        else:
            print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
            print("  This will likely cause LoRA compatibility issues!")
            # torch was imported at the top of this function (ZeroGPU requirement)
            load_kwargs["torch_dtype"] = torch.bfloat16

            # Explicitly disable MX
            os.environ["FORCE_MX_QUANTIZATION"] = "0"
    else:
        # Non-GPT-OSS models load directly in bf16
        load_kwargs["torch_dtype"] = torch.bfloat16
    
    try:
        # Load the model
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
        
        # Verify format
        print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
        if IS_GPT_OSS:
            is_mx = detect_mx_format(model)
            if is_mx:
                print("βœ“ Confirmed: Using native MX format")
            else:
                print("⚠ Model dequantized to bf16 - LoRA may fail")
        
        # Set model config
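        # gpt-oss tokenizers may ship without a pad token; fall back to EOS so
        # generate() can pad without erroring.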
        if getattr(model.config, "pad_token_id", None) is None:
            model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
        model.config.use_cache = True
        
        return model
        
    except Exception as e:
        if "ragged_tma" in str(e):
            print("\n" + "="*60)
            print("ERROR: Triton version incompatibility detected!")
            print("The model requires a specific Triton version with ragged_tma support.")
            print("\nTo fix this, run:")
            print("pip uninstall -y triton triton_kernels")
            print("pip install --index-url https://download.pytorch.org/whl/nightly/cu121 triton")
            print("pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels")
            print("="*60 + "\n")
            
            # Try to load without MX as fallback
            print("Attempting to load model without MX format...")
            load_kwargs["torch_dtype"] = torch.bfloat16
            os.environ["FORCE_MX_QUANTIZATION"] = "0"
            model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
            print("βœ“ Model loaded in bf16 mode (degraded performance)")
            return model
        else:
            raise

def load_lora_adapter(model, adapter_id: str, subfolder: Optional[str] = None):
    """Load and attach LoRA adapter for bf16 model."""
    if not _HAS_PEFT:
        raise RuntimeError("PEFT is required for LoRA adapters")
    
    print(f"\n{'='*50}")
    print(f"Loading LoRA: {adapter_id}")
    if subfolder:
        print(f"Subfolder: {subfolder}")
    print(f"{'='*50}\n")
    
    # Prepare kwargs for PEFT
    peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
    if subfolder:
        peft_kwargs["subfolder"] = subfolder
    
    try:
        # Load adapter configuration
        peft_config = PeftConfig.from_pretrained(adapter_id, **peft_kwargs)
        print(f"LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")
        
        # Load the adapter
        model = PeftModel.from_pretrained(model, adapter_id, **peft_kwargs)
        
        # Warning about potential mismatch
        if IS_GPT_OSS:
            print("⚠ WARNING: LoRA may have been trained on MX format")
            print("  Model is running in bf16 mode - there may be compatibility issues")
            print("  If generation quality is poor, the LoRA may need retraining on bf16")
        
        print("βœ“ LoRA adapter loaded")
        
        # Optionally merge adapter
        if MERGE_ADAPTER and hasattr(model, 'merge_and_unload'):
            print("Merging adapter into base model...")
            model = model.merge_and_unload()
            print("βœ“ Adapter merged")
        
        return model
        
    except Exception as e:
        print(f"βœ— Failed to load LoRA: {e}")
        print("Continuing with base model only")
        return model

# ===== HARMONY FORMATTING =====

def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high"):
    """Create Harmony-formatted prompt."""
    if not HARMONY_AVAILABLE or not harmony_encoding:
        # Fallback to chat template
        if messages and messages[0].get("role") != "system":
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    
    # Map reasoning effort
    effort_map = {
        "low": ReasoningEffort.LOW,
        "medium": ReasoningEffort.MEDIUM, 
        "high": ReasoningEffort.HIGH
    }
    effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)
    
    # Build Harmony conversation
    system_content = (
        SystemContent.new()
        .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
        .with_reasoning_effort(effort)
        .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
        .with_knowledge_cutoff("2024-06")
        .with_required_channels(REQUIRED_CHANNELS)
    )
    
    # Extract system prompt
    sys_text = SYSTEM_PROMPT
    rest = messages or []
    if rest and rest[0].get("role") == "system":
        sys_text = rest[0].get("content", SYSTEM_PROMPT)
        rest = rest[1:]
    
    # Build messages
    harmony_messages = [
        Message.from_role_and_content(Role.SYSTEM, system_content),
        Message.from_role_and_content(
            Role.DEVELOPER, 
            DeveloperContent.new().with_instructions(sys_text)
        )
    ]
    
    for msg in rest:
        role = msg.get("role")
        content = msg.get("content", "")
        if role == "user":
            harmony_messages.append(Message.from_role_and_content(Role.USER, content))
        elif role == "assistant":
            harmony_messages.append(
                Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
            )
    
    # Render to token IDs
    convo = Conversation.from_messages(harmony_messages)
    return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
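
# For reference, the rendered conversation corresponds to token text roughly of the
# form (illustrative, abbreviated):
#   <|start|>system<|message|>...identity, reasoning effort, channels...<|end|>
#   <|start|>developer<|message|>...instructions (system prompt)...<|end|>
#   <|start|>user<|message|>...<|end|>
#   <|start|>assistant
# The trailing "<|start|>assistant" cues the model to begin its reply.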

def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
    """Parse Harmony response tokens into channels."""
    if not HARMONY_AVAILABLE or not harmony_encoding:
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel(text), "raw": text}
    
    try:
        # Parse using Harmony
        parsed = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
        
        channels = {}
        for msg in parsed:
            channel = getattr(msg, 'channel', 'final')
            if channel not in channels:
                channels[channel] = ""
            
            # Extract text content
            content = msg.content
            if isinstance(content, list):
                text = "".join([getattr(part, "text", str(part)) for part in content])
            else:
                text = getattr(content, "text", str(content))
            
            channels[channel] += text
        
        # Ensure final channel exists
        if "final" not in channels:
            channels["final"] = " ".join(channels.values())
        
        return channels
        
    except Exception as e:
        print(f"Harmony parsing failed: {e}")
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel(text), "raw": text}

def extract_final_channel(text: str) -> str:
    """Extract final channel from raw text."""
    # Look for <|channel|>final<|message|>
    if "<|channel|>final<|message|>" in text:
        parts = text.split("<|channel|>final<|message|>")
        if len(parts) > 1:
            final = parts[-1]
            # Truncate at next marker
            for marker in ["<|channel|>", "<|end|>", "<|return|>"]:
                if marker in final:
                    final = final.split(marker)[0]
            return final.strip()
    
    # Fallback: return cleaned text
    for marker in ["<|channel|>", "<|message|>", "<|end|>", "<|return|>"]:
        text = text.replace(marker, " ")
    return text.strip()

# ===== GENERATION =====

@spaces.GPU(duration=120)
def generate_on_gpu(
    prompt,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    do_sample: bool,
    repetition_penalty: float,
    seed: Optional[int]
) -> Dict[str, str]:
    """Run generation on GPU."""
    import torch  # Import torch inside GPU function for ZeroGPU
    
    try:
        # Set seed if provided
        if seed is not None:
            torch.manual_seed(int(seed))
        
        # Load model
        print("\nLoading model for generation...")
        model = load_base_model("auto")
        
        # Load LoRA if specified
        if ADAPTER_ID:
            model = load_lora_adapter(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
        
        model.eval()
        
        # Prepare inputs
        device = next(model.parameters()).device
        
        if HARMONY_AVAILABLE and isinstance(prompt, list):
            # Harmony returns token IDs
            input_ids = torch.tensor([prompt], dtype=torch.long, device=device)
        else:
            # String prompt
            inputs = tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)
        
        attention_mask = torch.ones_like(input_ids)
        prompt_len = input_ids.shape[1]
        
        # Generate
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_new_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k) if top_k > 0 else None,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                pad_token_id=model.config.pad_token_id,
                eos_token_id=HARMONY_STOP_IDS if HARMONY_STOP_IDS else tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
            )
        
        # Extract generated tokens
        gen_tokens = outputs[0][prompt_len:].tolist()
        
        # Truncate at stop tokens
        for stop_id in HARMONY_STOP_IDS:
            if stop_id in gen_tokens:
                gen_tokens = gen_tokens[:gen_tokens.index(stop_id)]
                break
        
        # Parse response
        channels = parse_harmony_response(gen_tokens)
        
        return channels
        
    except Exception as e:
        error_msg = f"Generation failed: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {"final": f"Error: {str(e)}", "raw": error_msg}
    
    finally:
        # Cleanup
        if 'model' in locals():
            del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# ===== GRADIO INTERFACE =====

def chat_response(
    message: str,
    history: List[List[str]],
    system_prompt: str,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
    do_sample: bool,
    repetition_penalty: float,
    seed: Optional[int],
    reasoning_effort: str,
    show_thinking: bool
) -> str:
    """Handle chat interaction."""
    try:
        # Build conversation
        messages = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]
        
        # Add history (Gradio may pass [user, assistant] pairs or message dicts,
        # depending on the ChatInterface configuration)
        for turn in history or []:
            if isinstance(turn, dict):
                role, content = turn.get("role"), turn.get("content")
                if role in ("user", "assistant") and content:
                    messages.append({"role": role, "content": str(content)})
            elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
                user_msg, assistant_msg = turn[0], turn[1]
                if user_msg:
                    messages.append({"role": "user", "content": str(user_msg)})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": str(assistant_msg)})
        
        # Add current message
        messages.append({"role": "user", "content": message})
        
        # Create prompt
        prompt = create_harmony_prompt(messages, reasoning_effort)
        
        # Generate
        channels = generate_on_gpu(
            prompt,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            do_sample,
            repetition_penalty,
            seed
        )
        
        # Format response
        if show_thinking and len(channels) > 1:
            response = "## Chain of Thought:\n\n"
            for channel, content in channels.items():
                if channel != "final" and content:
                    response += f"### {channel.capitalize()}:\n{content}\n\n"
            response += f"### Final Response:\n{channels.get('final', 'No response generated')}"
        else:
            response = channels.get("final", "No response generated")
        
        return response
        
    except Exception as e:
        return f"Error: {str(e)}"

# ===== BUILD UI =====

with gr.Blocks(theme=gr.themes.Soft(), title="Mirel") as demo:
    # Header with status
    status_mx = "βœ… MX Format" if _HAS_TRITON_KERNELS else "❌ No MX Support"
    status_harmony = "βœ… Harmony" if HARMONY_AVAILABLE else "❌ No Harmony"
    
    gr.Markdown(f"""
    # 🤖 Mirel – Chain-of-Thought Assistant
    
    **Model:** `{MODEL_ID}` | **Adapter:** `{ADAPTER_ID or 'None'}`  
    **Status:** {status_mx} | {status_harmony} | {"✅ ZeroGPU" if ZEROGPU else "CPU Mode"}
    
    {'''
    ⚠️ **WARNING: MX Format Support Missing!**
    Install with: `pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels`
    ''' if IS_GPT_OSS and not _HAS_TRITON_KERNELS else ''}
    """)
    
    # System prompt
    system_prompt = gr.Textbox(
        label="System Prompt",
        value=SYSTEM_PROMPT,
        lines=2
    )
    
    # Settings
    with gr.Accordion("βš™οΈ Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
            top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
        
        with gr.Row():
            max_new_tokens = gr.Slider(16, 2048, value=MAX_NEW_TOKENS, step=16, label="Max tokens")
            repetition_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label="Repetition penalty")
            seed = gr.Number(value=None, label="Seed (optional)", precision=0)
        
        with gr.Row():
            do_sample = gr.Checkbox(value=True, label="Sample")
            show_thinking = gr.Checkbox(value=False, label="Show thinking channels")
            reasoning_effort = gr.Radio(
                ["low", "medium", "high"],
                value="high",
                label="Reasoning effort"
            )
    
    # Chat interface
    chat = gr.ChatInterface(
        fn=chat_response,
        additional_inputs=[
            system_prompt,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            do_sample,
            repetition_penalty,
            seed,
            reasoning_effort,
            show_thinking
        ],
        title=None,
        examples=[
            ["Hello! Can you introduce yourself?"],
            ["What's the capital of France?"],
            ["Explain quantum computing simply"],
            ["Write a haiku about coding"],
        ],
        cache_examples=False,
    )
    
    # Footer
    gr.Markdown("""
    ---
    💡 **Tips:**
    - Enable "Show thinking channels" to see the model's reasoning process
    - Adjust "Reasoning effort" for faster responses (low) or better quality (high)
    - This build runs in bf16; native MX quantization needs triton_kernels support that is not yet in stable releases
    """)

# ===== LAUNCH =====
if __name__ == "__main__":
    print("\n" + "="*60)
    print("MIREL READY TO LAUNCH")
    print(f"Model: {MODEL_ID}")
    print(f"Adapter: {ADAPTER_ID or 'None'}")
    print(f"MX Format: {'ENABLED' if _HAS_TRITON_KERNELS else 'DISABLED'}")
    print(f"Harmony: {'ENABLED' if HARMONY_AVAILABLE else 'DISABLED'}")
    print("="*60 + "\n")
    
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )