Spaces: Running on Zero

yes
AbstractPhil committed · Commit f7e1fb5 · 1 Parent(s): 1189725

Files changed:
- app.py (+67 -49)
- install.sh (+14 -5)
- setup.py (+51 -20)
app.py CHANGED

@@ -16,40 +16,24 @@ gradio>=5.42.0
 triton>=3.4.0
 git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
 """
+from __future__ import annotations
 
-# …
-
-import subprocess, sys
-
-def ensure_triton_kernels():
-    """…"""
+# Import setup to fix Triton if needed
+try:
+    import setup  # This will run install.sh if Triton needs fixing
+except ImportError:
+    print("No setup.py found - checking Triton manually")
+    # Fallback check
     try:
         import triton_kernels
-        …
-        return True
+        from triton.tools.ragged_tma import load_ragged
+        print("✓ Triton configured correctly")
     except ImportError:
-        print("…")
-        try:
-            subprocess.check_call([
-                sys.executable, "-m", "pip", "install",
-                "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
-            ])
-            print("✓ triton_kernels installed successfully")
-            # Force reimport
-            import importlib
-            import site
-            importlib.reload(site)
-            return True
-        except subprocess.CalledProcessError as e:
-            print(f"✗ Failed to install triton_kernels: {e}")
-            print("ERROR: MX format will NOT work properly without triton_kernels!")
-            return False
-
-# Install triton_kernels before other imports
-_TRITON_INSTALL_SUCCESS = ensure_triton_kernels()
+        print("✗ Triton not configured for MX - run install.sh")
 
 # ===== MAIN IMPORTS =====
 import os, gc, json, torch, warnings, traceback
+import subprocess, sys
 from dataclasses import dataclass
 from typing import List, Dict, Optional, Any, Union
 from datetime import datetime

@@ -90,14 +74,17 @@ except Exception:
     _HAS_PEFT = False
     print("✗ PEFT not available. Install with: pip install peft")
 
-# Check for triton_kernels
+# Check for triton_kernels after setup
 try:
     import triton_kernels
+    # Also check for the specific module that was missing
+    from triton.tools.ragged_tma import load_ragged, store_ragged
     _HAS_TRITON_KERNELS = True
-    print("✓ triton_kernels loaded - MX format enabled")
-except ImportError:
+    print("✓ triton_kernels loaded with ragged_tma support - MX format enabled")
+except ImportError as e:
     _HAS_TRITON_KERNELS = False
-    print("✗ triton_kernels not …")
+    print(f"✗ triton_kernels not fully functional: {e}")
+    print("MX format will fall back to bf16 - LoRA may not work correctly")
 
 # ===== CONFIGURATION =====
 MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")

@@ -202,33 +189,64 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
     if IS_GPT_OSS:
         if _HAS_TRITON_KERNELS:
             print("✓ Loading with native MX format support")
-            …
+            # For MX format, let the model handle its own dtype
+            load_kwargs["torch_dtype"] = "auto"
+
+            # Set environment variable to ensure MX is used
+            import os
+            os.environ["FORCE_MX_QUANTIZATION"] = "1"
         else:
             print("✗ No triton_kernels - falling back to bf16 (dequantized)")
             print("  This will likely cause LoRA compatibility issues!")
             load_kwargs["torch_dtype"] = torch.bfloat16
+
+            # Explicitly disable MX
+            import os
+            os.environ["FORCE_MX_QUANTIZATION"] = "0"
     else:
         # Non-GPT-OSS models
        load_kwargs["torch_dtype"] = torch.bfloat16
 
-    …
+    try:
+        # Load the model
+        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
+
+        # Verify format
+        print(f"Model loaded - dtype: {next(model.parameters()).dtype}")
+        if IS_GPT_OSS:
+            is_mx = detect_mx_format(model)
+            if is_mx:
+                print("✓ Confirmed: Using native MX format")
+            else:
+                print("✗ Model dequantized to bf16 - LoRA may fail")
+
+        # Set model config
+        if getattr(model.config, "pad_token_id", None) is None:
+            model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
+        model.config.use_cache = True
+
+        return model
+
+    except Exception as e:
+        if "ragged_tma" in str(e):
+            print("\n" + "="*60)
+            print("ERROR: Triton version incompatibility detected!")
+            print("The model requires a specific Triton version with ragged_tma support.")
+            print("\nTo fix this, run:")
+            print("pip uninstall -y triton triton_kernels")
+            print("pip install --index-url https://download.pytorch.org/whl/nightly/cu121 triton")
+            print("pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels")
+            print("="*60 + "\n")
+
+            # Try to load without MX as fallback
+            print("Attempting to load model without MX format...")
+            load_kwargs["torch_dtype"] = torch.bfloat16
+            os.environ["FORCE_MX_QUANTIZATION"] = "0"
+            model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
+            print("✓ Model loaded in bf16 mode (degraded performance)")
+            return model
         else:
-            …
-    # Set model config
-    if getattr(model.config, "pad_token_id", None) is None:
-        model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
-    model.config.use_cache = True
-
-    return model
+            raise
 
 def load_lora_adapter(model, adapter_id: str, subfolder: Optional[str] = None):
     """Load and attach LoRA adapter with MX format handling."""
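Taken together, the app.py changes replace the old in-process pip install with a capability probe: MX support is only considered available when both the triton_kernels package and the triton.tools.ragged_tma module import cleanly, and FORCE_MX_QUANTIZATION is flipped to match. A minimal standalone sketch of that probe follows; the helper name has_mx_support is ours for illustration and is not part of the commit.

def has_mx_support() -> bool:
    """Probe for the MX-format prerequisites without installing anything.

    Mirrors the check app.py now performs: stock Triton wheels lack the
    ragged_tma module, so both imports must succeed before MX is enabled.
    """
    try:
        import triton_kernels  # kernel package from the Triton repo
        from triton.tools.ragged_tma import load_ragged, store_ragged
        return True
    except ImportError:
        return False

if __name__ == "__main__":
    import os
    # Same convention the commit uses to signal the loader.
    os.environ["FORCE_MX_QUANTIZATION"] = "1" if has_mx_support() else "0"
    print(f"MX support: {has_mx_support()}")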
install.sh CHANGED

@@ -10,7 +10,6 @@ pip install --upgrade pip
 pip install huggingface_hub>=0.34.0
 pip install transformers>=4.55.0
 pip install accelerate>=0.33.0
-pip install torch>=2.4.0
 pip install gradio>=5.42.0
 pip install spaces
 
@@ -21,11 +20,21 @@ pip install bitsandbytes>=0.43.1
 # Install Harmony format
 pip install openai-harmony
 
-# …
-…
+# FIX TRITON FOR MX FORMAT
+# The standard triton doesn't have ragged_tma module needed for MX
+echo "Fixing Triton installation for MX format..."
 
-# …
-…
+# Clean existing triton installations
+pip uninstall -y triton triton_kernels 2>/dev/null || true
+
+# Install PyTorch nightly (includes compatible Triton)
+echo "Installing PyTorch nightly with compatible Triton..."
+pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+
+# Install Triton from PyTorch nightly
+pip install --upgrade --index-url https://download.pytorch.org/whl/nightly/cu121 triton
+
+# Install triton_kernels from source
 echo "Installing triton_kernels (REQUIRED for MX format)..."
 pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
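Because install.sh now swaps out the torch/triton stack wholesale (pinned torch is dropped from the base installs and the nightly index takes over), a quick post-install check is worth running before launching the Space. A hedged sketch, not part of the commit, that verifies the modules the script is supposed to provide without fully importing them:

import importlib.util

def probe(mod: str) -> bool:
    """True if `mod` resolves to an importable spec; tolerate missing parents."""
    try:
        return importlib.util.find_spec(mod) is not None
    except ModuleNotFoundError:
        return False  # a parent package is missing entirely

# ragged_tma is the module whose absence motivated this commit.
for mod in ("triton", "triton_kernels", "triton.tools.ragged_tma"):
    print(f"{mod}: {'present' if probe(mod) else 'MISSING'}")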
setup.py CHANGED

@@ -1,31 +1,62 @@
 """
-setup.py - Run this at the start of app.py to ensure …
-Add …
+setup.py - Run this at the start of app.py to ensure proper Triton installation
+Add: import setup  # at the top of app.py after the docstring
 """
 
 import subprocess
 import sys
+import os
 
-def …
-    """…"""
+def fix_triton_installation():
+    """Fix Triton for MX format by running install.sh if needed."""
     try:
+        # Check if we have the right triton
         import triton_kernels
-        …
+        from triton.tools.ragged_tma import load_ragged, store_ragged
+        print("✓ Triton already properly configured for MX format")
         return True
     except ImportError:
-        print("…")
-        …
+        print("Triton not properly configured for MX format")
+        print("Running install.sh to fix dependencies...")
+
+        # Check if install.sh exists
+        if os.path.exists("install.sh"):
+            try:
+                # Make it executable and run it
+                subprocess.check_call(["chmod", "+x", "install.sh"])
+                subprocess.check_call(["./install.sh"])
+                print("✓ Dependencies installed via install.sh")
+                return True
+            except subprocess.CalledProcessError as e:
+                print(f"Error running install.sh: {e}")
+        else:
+            print("install.sh not found - trying direct pip fix...")
+
+        # Fallback: run key commands directly
+        try:
+            # Clean and reinstall triton
+            subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "triton", "triton_kernels"],
+                           capture_output=True)
+
+            # Install nightly triton
+            subprocess.check_call([
+                sys.executable, "-m", "pip", "install", "--upgrade",
+                "--index-url", "https://download.pytorch.org/whl/nightly/cu121",
+                "triton"
+            ])
+
+            # Install triton_kernels
+            subprocess.check_call([
+                sys.executable, "-m", "pip", "install",
+                "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"
+            ])
+
+            print("✓ Triton fixed via direct installation")
+            return True
+        except Exception as e:
+            print(f"Failed to fix Triton: {e}")
+            return False
 
-# …
-if __name__ != "__main__":
-    …
+# Auto-run on import
+if __name__ != "__main__":
+    fix_triton_installation()
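For completeness, this is how the new setup.py is meant to be wired in, per its own docstring; a sketch of the top of app.py under that assumption:

"""app.py - Space entry point."""
import setup  # side effect: fix_triton_installation() runs on import

# Heavy imports only after Triton has been patched:
import torch
from transformers import AutoModelForCausalLM

The import order matters: a freshly pip-installed Triton only takes effect for modules imported after the fix runs, which is why setup must be the first import rather than being called later in the startup path.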