AbstractPhil committed
Commit f7e1fb5 · 1 Parent(s): 1189725
Files changed (3):
  1. app.py +67 -49
  2. install.sh +14 -5
  3. setup.py +51 -20
app.py CHANGED
@@ -16,40 +16,24 @@ gradio>=5.42.0
 triton>=3.4.0
 git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
 """
+from __future__ import annotations
 
-# ===== SETUP: Ensure triton_kernels is installed for MX format =====
-import subprocess
-import sys
-
-def ensure_triton_kernels():
-    """Ensure triton_kernels is installed for MX format support on H200."""
+# Import setup to fix Triton if needed
+try:
+    import setup  # This will run install.sh if Triton needs fixing
+except ImportError:
+    print("No setup.py found - checking Triton manually")
+    # Fallback check
     try:
         import triton_kernels
-        print("✓ triton_kernels already installed - MX format supported")
-        return True
+        from triton.tools.ragged_tma import load_ragged
+        print("✓ Triton configured correctly")
     except ImportError:
-        print("Installing triton_kernels for MX format support...")
-        try:
-            subprocess.check_call([
-                sys.executable, "-m", "pip", "install",
-                "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
-            ])
-            print("✓ triton_kernels installed successfully")
-            # Force reimport
-            import importlib
-            import site
-            importlib.reload(site)
-            return True
-        except subprocess.CalledProcessError as e:
-            print(f"✗ Failed to install triton_kernels: {e}")
-            print("ERROR: MX format will NOT work properly without triton_kernels!")
-            return False
-
-# Install triton_kernels before other imports
-_TRITON_INSTALL_SUCCESS = ensure_triton_kernels()
+        print("⚠ Triton not configured for MX - run install.sh")
 
 # ===== MAIN IMPORTS =====
 import os, gc, json, torch, warnings, traceback
+import subprocess, sys
 from dataclasses import dataclass
 from typing import List, Dict, Optional, Any, Union
 from datetime import datetime
@@ -90,14 +74,17 @@ except Exception:
     _HAS_PEFT = False
     print("⚠ PEFT not available. Install with: pip install peft")
 
-# Check for triton_kernels (required for MX format)
+# Check for triton_kernels after setup
 try:
     import triton_kernels
+    # Also check for the specific module that was missing
+    from triton.tools.ragged_tma import load_ragged, store_ragged
     _HAS_TRITON_KERNELS = True
-    print("✓ triton_kernels loaded - MX format enabled")
-except ImportError:
+    print("✓ triton_kernels loaded with ragged_tma support - MX format enabled")
+except ImportError as e:
     _HAS_TRITON_KERNELS = False
-    print("✗ triton_kernels not available - MX format disabled!")
+    print(f"✗ triton_kernels not fully functional: {e}")
+    print("MX format will fall back to bf16 - LoRA may not work correctly")
 
 # ===== CONFIGURATION =====
 MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
@@ -202,33 +189,64 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
     if IS_GPT_OSS:
         if _HAS_TRITON_KERNELS:
             print("→ Loading with native MX format support")
-            load_kwargs["torch_dtype"] = "auto"  # Let model use native MX
+            # For MX format, let the model handle its own dtype
+            load_kwargs["torch_dtype"] = "auto"
+
+            # Set environment variable to ensure MX is used
+            import os
+            os.environ["FORCE_MX_QUANTIZATION"] = "1"
         else:
             print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
            print("  This will likely cause LoRA compatibility issues!")
             load_kwargs["torch_dtype"] = torch.bfloat16
+
+            # Explicitly disable MX
+            import os
+            os.environ["FORCE_MX_QUANTIZATION"] = "0"
     else:
         # Non-GPT-OSS models
         load_kwargs["torch_dtype"] = torch.bfloat16
 
-    # Load the model
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
-
-    # Verify format
-    print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
-    if IS_GPT_OSS:
-        is_mx = detect_mx_format(model)
-        if is_mx:
-            print("✓ Confirmed: Using native MX format")
-        else:
-            print("⚠ Model dequantized to bf16 - LoRA may fail")
-
-    # Set model config
-    if getattr(model.config, "pad_token_id", None) is None:
-        model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
-    model.config.use_cache = True
-
-    return model
+    try:
+        # Load the model
+        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
+
+        # Verify format
+        print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
+        if IS_GPT_OSS:
+            is_mx = detect_mx_format(model)
+            if is_mx:
+                print("✓ Confirmed: Using native MX format")
+            else:
+                print("⚠ Model dequantized to bf16 - LoRA may fail")
+
+        # Set model config
+        if getattr(model.config, "pad_token_id", None) is None:
+            model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+        model.config.use_cache = True
+
+        return model
+
+    except Exception as e:
+        if "ragged_tma" in str(e):
+            print("\n" + "="*60)
+            print("ERROR: Triton version incompatibility detected!")
+            print("The model requires a specific Triton version with ragged_tma support.")
+            print("\nTo fix this, run:")
+            print("pip uninstall -y triton triton_kernels")
+            print("pip install --index-url https://download.pytorch.org/whl/nightly/cu121 triton")
+            print("pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels")
+            print("="*60 + "\n")
+
+            # Try to load without MX as fallback
+            print("Attempting to load model without MX format...")
+            load_kwargs["torch_dtype"] = torch.bfloat16
+            os.environ["FORCE_MX_QUANTIZATION"] = "0"
+            model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
+            print("✓ Model loaded in bf16 mode (degraded performance)")
+            return model
+        else:
+            raise
 
 def load_lora_adapter(model, adapter_id: str, subfolder: Optional[str] = None):
     """Load and attach LoRA adapter with MX format handling."""
 
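Note on the verification step above: `detect_mx_format` is called in `load_base_model` but is not part of this diff. As a rough illustration of what that check can do, here is a minimal sketch, assuming MX-quantized checkpoints keep weights in packed integer tensors while a dequantized model is all floating point; the heuristic is hypothetical, not the repo's actual implementation:

```python
def detect_mx_format(model) -> bool:
    """Sketch: report True if any weight tensor is stored in a packed
    integer dtype (e.g. uint8), as MX-quantized weights would be, rather
    than in a floating-point dtype like the bf16 of a dequantized model."""
    tensors = list(model.named_parameters()) + list(model.named_buffers())
    return any(
        "weight" in name and not tensor.is_floating_point()
        for name, tensor in tensors
    )
```

With the MX path active this should return True immediately after `from_pretrained`; False corresponds to the "⚠ Model dequantized to bf16" branch above.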
install.sh CHANGED
@@ -10,7 +10,6 @@ pip install --upgrade pip
 pip install huggingface_hub>=0.34.0
 pip install transformers>=4.55.0
 pip install accelerate>=0.33.0
-pip install torch>=2.4.0
 pip install gradio>=5.42.0
 pip install spaces
 
@@ -21,11 +20,21 @@ pip install bitsandbytes>=0.43.1
 # Install Harmony format
 pip install openai-harmony
 
-# Install Triton and MX format support
-pip install triton>=3.4.0
-
-# CRITICAL: Install triton_kernels from git subdirectory
-# This is REQUIRED for MX format on H200 GPUs
+# FIX TRITON FOR MX FORMAT
+# The standard triton doesn't have the ragged_tma module needed for MX
+echo "Fixing Triton installation for MX format..."
+
+# Clean existing triton installations
+pip uninstall -y triton triton_kernels 2>/dev/null || true
+
+# Install PyTorch nightly (includes compatible Triton)
+echo "Installing PyTorch nightly with compatible Triton..."
+pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+
+# Install Triton from PyTorch nightly
+pip install --upgrade --index-url https://download.pytorch.org/whl/nightly/cu121 triton
+
+# Install triton_kernels from source
 echo "Installing triton_kernels (REQUIRED for MX format)..."
 pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
 
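Two small notes on this script. First, unquoted version specifiers such as `pip install transformers>=4.55.0` are parsed by the shell as output redirection (stdout goes to a file named `=4.55.0`), so the quoted form `pip install "transformers>=4.55.0"` is the safe one. Second, a quick post-install sanity check can confirm the reinstall actually produced a Triton with ragged_tma support before the Space boots. A minimal sketch (the filename `verify_triton.py` is just a suggestion), probing the same modules app.py and setup.py import:

```python
# verify_triton.py - run after install.sh to confirm the MX-capable stack
import importlib

for mod in ("triton", "triton_kernels", "triton.tools.ragged_tma"):
    try:
        importlib.import_module(mod)
        print(f"✓ {mod}")
    except ImportError as exc:
        print(f"✗ {mod}: {exc}")
        raise SystemExit(1)
print("Triton stack ready for MX format")
```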
setup.py CHANGED
@@ -1,31 +1,62 @@
 """
-setup.py - Run this at the start of app.py to ensure triton_kernels is installed
-Add this to the top of your app.py file in HF Spaces
+setup.py - Run this at the start of app.py to ensure proper Triton installation
+Add: import setup  # at the top of app.py after the docstring
 """
 
 import subprocess
 import sys
+import os
 
-def ensure_triton_kernels():
-    """Ensure triton_kernels is installed for MX format support."""
+def fix_triton_installation():
+    """Fix Triton for MX format by running install.sh if needed."""
     try:
+        # Check if we have the right triton
         import triton_kernels
-        print("✓ triton_kernels already installed")
+        from triton.tools.ragged_tma import load_ragged, store_ragged
+        print("✓ Triton already properly configured for MX format")
         return True
     except ImportError:
-        print("Installing triton_kernels for MX format support...")
-        try:
-            subprocess.check_call([
-                sys.executable, "-m", "pip", "install",
-                "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
-            ])
-            print("✓ triton_kernels installed successfully")
-            return True
-        except subprocess.CalledProcessError as e:
-            print(f"✗ Failed to install triton_kernels: {e}")
-            print("WARNING: MX format will fall back to bf16, LoRA may not work!")
-            return False
+        print("Triton not properly configured for MX format")
+        print("Running install.sh to fix dependencies...")
+
+        # Check if install.sh exists
+        if os.path.exists("install.sh"):
+            try:
+                # Make it executable and run it
+                subprocess.check_call(["chmod", "+x", "install.sh"])
+                subprocess.check_call(["./install.sh"])
+                print("✓ Dependencies installed via install.sh")
+                return True
+            except subprocess.CalledProcessError as e:
+                print(f"Error running install.sh: {e}")
+        else:
+            print("install.sh not found - trying direct pip fix...")
+
+        # Fallback: run key commands directly
+        try:
+            # Clean and reinstall triton
+            subprocess.run(
+                [sys.executable, "-m", "pip", "uninstall", "-y", "triton", "triton_kernels"],
+                capture_output=True,
+            )
+
+            # Install nightly triton
+            subprocess.check_call([
+                sys.executable, "-m", "pip", "install", "--upgrade",
+                "--index-url", "https://download.pytorch.org/whl/nightly/cu121",
+                "triton"
+            ])
+
+            # Install triton_kernels
+            subprocess.check_call([
+                sys.executable, "-m", "pip", "install",
+                "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
+            ])
+
+            print("✓ Triton fixed via direct installation")
+            return True
+        except Exception as e:
+            print(f"Failed to fix Triton: {e}")
+            return False
 
-# Run at import time
-if __name__ != "__main__":  # When imported
-    ensure_triton_kernels()
+# Auto-run on import
+if __name__ != "__main__":
+    fix_triton_installation()
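Because setup.py installs packages into a Python process that has already started, a module it installs is not automatically importable afterwards; the caller has to refresh the import system (a full restart is the robust fix). A minimal usage sketch for the top of app.py under that assumption (`MX_READY` is an illustrative name, not from this commit):

```python
import importlib

import setup  # runs fix_triton_installation() at import time

# Packages installed mid-process only become importable after the
# import caches are invalidated (or the interpreter is restarted).
importlib.invalidate_caches()

try:
    import triton_kernels  # noqa: F401
    from triton.tools.ragged_tma import load_ragged  # noqa: F401
    MX_READY = True
except ImportError:
    MX_READY = False

print(f"MX format available: {MX_READY}")
```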