Spaces:
Running
on
Zero
Running
on
Zero
alex
committed on
Commit
·
7e16671
1
Parent(s):
2085d6e
flash attention upgraded for cuda 12.8
Browse files
- app.py +21 -10
- requirements.txt +0 -1
app.py
CHANGED
@@ -24,20 +24,31 @@ importlib.invalidate_caches()
|
|
24 |
|
25 |
def sh(cmd): subprocess.check_call(cmd, shell=True)
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
sh(f"pip install {flash_attention_wheel}")
|
34 |
|
35 |
-
# tell Python to re-scan site-packages now that the egg-link exists
|
36 |
-
import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
|
37 |
|
38 |
-
|
|
|
39 |
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
import torch.nn as nn
|
|
|
24 |
|
25 |
def sh(cmd): subprocess.check_call(cmd, shell=True)
|
26 |
|
27 |
+
flash_attention_installed = False
|
28 |
+
|
29 |
+
try:
|
30 |
+
print("Attempting to download and install FlashAttention wheel...")
|
31 |
+
flash_attention_wheel = hf_hub_download(
|
32 |
+
repo_id="alexnasa/flash-attn-3",
|
33 |
+
repo_type="model",
|
34 |
+
filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
|
35 |
+
)
|
36 |
|
37 |
+
sh(f"pip install {flash_attention_wheel}")
|
38 |
|
39 |
+
# tell Python to re-scan site-packages now that the egg-link exists
|
40 |
+
import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
|
41 |
|
42 |
+
flash_attention_installed = True
|
43 |
+
print("FlashAttention installed successfully.")
|
44 |
|
45 |
+
except Exception as e:
|
46 |
+
print(f"⚠️ Could not install FlashAttention: {e}")
|
47 |
+
print("Continuing without FlashAttention...")
|
48 |
+
|
49 |
+
import torch
|
50 |
+
print(f"Torch version: {torch.__version__}")
|
51 |
+
print(f"FlashAttention available: {flash_attention_installed}")
|
52 |
|
53 |
|
54 |
import torch.nn as nn
|
requirements.txt
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
torch==2.7.1
|
2 |
tqdm
|
3 |
librosa==0.10.2.post1
|
4 |
peft==0.15.1
|
|
|
|
|
1 |
tqdm
|
2 |
librosa==0.10.2.post1
|
3 |
peft==0.15.1
|