Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,423 Bytes
6228595 f7e1fb5 6228595 f7e1fb5 6228595 f7e1fb5 6228595 f7e1fb5 6228595 f7e1fb5 6228595 f7e1fb5 6228595 f7e1fb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
"""
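setup.py - Import this module at the start of app.py to make sure Triton is installed correctly for MX format.
Usage: add `import setup` at the top of app.py, right after its docstring. The check runs on import; executing this file directly does nothing.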
"""
import importlib
import os
import stat
import subprocess
import sys
def _triton_mx_ready():
    """Return True if the Triton pieces needed for MX format can be imported."""
    try:
        import triton_kernels  # noqa: F401 -- importing is the check itself
        from triton.tools.ragged_tma import load_ragged, store_ragged  # noqa: F401
    except ImportError:
        return False
    return True


def _run_install_script():
    """Make ./install.sh (in the current working directory) executable and run it.

    Returns True if the script exited with status 0, False if it is missing,
    could not be launched, or failed.
    """
    if not os.path.isfile("install.sh"):
        print("install.sh not found - trying direct pip fix...")
        return False
    try:
        # Equivalent of `chmod +x install.sh` without spawning an external binary.
        mode = os.stat("install.sh").st_mode
        os.chmod("install.sh", mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        subprocess.check_call(["./install.sh"])
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers launch failures (no exec permission, missing shebang,
        # read-only filesystem). Without catching it here the whole app import
        # would crash instead of falling back to pip.
        print(f"Error running install.sh: {e}")
        return False
    print("✓ Dependencies installed via install.sh")
    return True


def _pip_reinstall_triton():
    """Reinstall nightly Triton plus triton_kernels via pip; return True on success."""
    try:
        # Best-effort cleanup: output is captured and the exit status ignored,
        # so a failed uninstall (e.g. nothing installed) does not stop us.
        subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", "-y", "triton", "triton_kernels"],
            capture_output=True,
        )
        # Install nightly triton.
        # NOTE(review): the cu121 index is hard-coded; confirm it matches the
        # CUDA version of the deployed torch build.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--upgrade",
            "--index-url", "https://download.pytorch.org/whl/nightly/cu121",
            "triton",
        ])
        # Install triton_kernels from the Triton repository's main branch.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels",
        ])
    except Exception as e:
        # Last-resort boundary: this runs at app import time, so report and
        # return False rather than propagating.
        print(f"Failed to fix Triton: {e}")
        return False
    print("✓ Triton fixed via direct installation")
    return True


def fix_triton_installation():
    """Fix Triton for MX format by running install.sh if needed.

    First checks whether ``triton_kernels`` and
    ``triton.tools.ragged_tma.{load_ragged, store_ragged}`` import. If they do,
    nothing is installed. Otherwise it tries ``./install.sh`` from the current
    working directory. If that script is missing or fails, it falls back to
    uninstalling and reinstalling ``triton`` / ``triton_kernels`` with pip.

    Returns:
        True if Triton was already usable or one of the install paths
        succeeded; False if every install attempt failed.
    """
    if _triton_mx_ready():
        print("✓ Triton already properly configured for MX format")
        return True

    print("Triton not properly configured for MX format")
    print("Running install.sh to fix dependencies...")
    fixed = _run_install_script() or _pip_reinstall_triton()
    if fixed:
        # Let the import system notice packages installed after interpreter
        # start-up. Modules already in sys.modules (e.g. an old `triton`) are
        # NOT reloaded; a process restart may still be required for those.
        importlib.invalidate_caches()
    return fixed
# Trigger the Triton check as a side effect of `import setup`. When this file
# is executed directly (__name__ == "__main__") nothing is run.
_IMPORTED_AS_MODULE = __name__ != "__main__"
if _IMPORTED_AS_MODULE:
    fix_triton_installation()