File size: 2,423 Bytes
6228595
f7e1fb5
 
6228595
 
 
 
f7e1fb5
6228595
f7e1fb5
 
6228595
f7e1fb5
6228595
f7e1fb5
 
6228595
 
f7e1fb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6228595
f7e1fb5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
setup.py - Run this at the start of app.py to ensure proper Triton installation
Add: import setup  # at the top of app.py after the docstring
"""

import importlib
import os
import subprocess
import sys

def _mx_triton_available():
    """Return True if the MX-format Triton pieces can be imported.

    Probes for the ``triton_kernels`` package and for ``load_ragged`` /
    ``store_ragged`` in ``triton.tools.ragged_tma``; an ImportError from
    either means the installed Triton is not the one this app needs.
    """
    try:
        import triton_kernels  # noqa: F401  (the import itself is the check)
        from triton.tools.ragged_tma import load_ragged, store_ragged  # noqa: F401
    except ImportError:
        return False
    return True


def _run_install_script(script):
    """Mark *script* executable and run it; return True iff it exits with 0."""
    try:
        # Same effect as `chmod +x`, without relying on an external binary.
        os.chmod(script, os.stat(script).st_mode | 0o111)
        subprocess.check_call([os.path.abspath(script)])
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing/non-executable script or a missing shebang;
        # report it instead of letting it crash the importing application.
        print(f"Error running install.sh: {e}")
        return False
    print("✓ Dependencies installed via install.sh")
    return True


def _pip_reinstall_triton():
    """Reinstall nightly Triton and triton_kernels via pip; True on success."""
    try:
        # Best-effort cleanup: the uninstall's exit status and output are
        # deliberately ignored (no check, output captured).
        subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", "-y", "triton", "triton_kernels"],
            capture_output=True,
        )

        # Install nightly triton
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "--upgrade",
            "--index-url", "https://download.pytorch.org/whl/nightly/cu121",
            "triton",
        ])

        # Install triton_kernels
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels",
        ])
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Failed to fix Triton: {e}")
        return False
    print("✓ Triton fixed via direct installation")
    return True


def fix_triton_installation(install_script="install.sh"):
    """Fix Triton for MX format by running install.sh if needed.

    If ``triton_kernels`` and ``triton.tools.ragged_tma`` already import,
    nothing is changed. Otherwise *install_script* is run when it exists
    (a relative path resolves against the current working directory). If it
    does not exist, Triton and triton_kernels are reinstalled directly with pip.

    Args:
        install_script: Path of the repair script. Defaults to
            ``"install.sh"`` in the current working directory.

    Returns:
        True if Triton was already usable or the repair step succeeded.
        False if the repair step failed.

    NOTE: success means the repair command exited with 0; the imports are not
    re-checked afterwards. A Triton module already imported in this process
    is not reloaded, so a restart may still be required.
    """
    if _mx_triton_available():
        print("✓ Triton already properly configured for MX format")
        return True

    print("Triton not properly configured for MX format")
    print("Running install.sh to fix dependencies...")

    if os.path.exists(install_script):
        fixed = _run_install_script(install_script)
    else:
        print("install.sh not found - trying direct pip fix...")
        fixed = _pip_reinstall_triton()

    if fixed:
        # Packages were just written to disk while this interpreter is
        # running; drop the import system's cached directory listings so a
        # later `import triton_kernels` can find them.
        importlib.invalidate_caches()
    return fixed

# Side effect on purpose: `import setup` (from app.py) triggers the Triton
# check/repair before anything that depends on Triton is imported.
# Executing this file directly as a script intentionally does nothing.
if __name__ == "__main__":
    pass
else:
    fix_triton_installation()