alex committed on
Commit 7e16671 · 1 Parent(s): 2085d6e

flash attention upgraded for cuda 12.8

Files changed (2):
  1. app.py +21 -10
  2. requirements.txt +0 -1
app.py CHANGED
@@ -24,20 +24,31 @@ importlib.invalidate_caches()
 
 def sh(cmd): subprocess.check_call(cmd, shell=True)
 
-flash_attention_wheel = hf_hub_download(
-    repo_id="alexnasa/flash-attn-3",
-    repo_type="model",
-    filename="flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
-)
-
-sh(f"pip install {flash_attention_wheel}")
-
-# tell Python to re-scan site-packages now that the egg-link exists
-import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
-
-import torch
-
-print(f'torch version:{torch.__version__}')
+flash_attention_installed = False
+
+try:
+    print("Attempting to download and install FlashAttention wheel...")
+    flash_attention_wheel = hf_hub_download(
+        repo_id="alexnasa/flash-attn-3",
+        repo_type="model",
+        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
+    )
+
+    sh(f"pip install {flash_attention_wheel}")
+
+    # tell Python to re-scan site-packages now that the egg-link exists
+    import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
+
+    flash_attention_installed = True
+    print("FlashAttention installed successfully.")
+
+except Exception as e:
+    print(f"⚠️ Could not install FlashAttention: {e}")
+    print("Continuing without FlashAttention...")
+
+import torch
+print(f"Torch version: {torch.__version__}")
+print(f"FlashAttention available: {flash_attention_installed}")
 
 
 import torch.nn as nn
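For context (not part of this commit): a minimal sketch of how code later in app.py might consume the flash_attention_installed flag, falling back to PyTorch's built-in scaled dot-product attention when the wheel could not be installed. The attention helper and the flash_attn_interface module name are assumptions for illustration, not taken from this repo.

import torch
import torch.nn.functional as F

flash_attention_installed = False  # in app.py this flag is set by the install block above

def attention(q, k, v):
    # q, k, v: (batch, seqlen, nheads, headdim); causal attention
    if flash_attention_installed:
        # Module/function names assumed from the FlashAttention 3 beta interface
        from flash_attn_interface import flash_attn_func
        result = flash_attn_func(q, k, v, causal=True)
        return result[0] if isinstance(result, tuple) else result
    # Fallback: F.scaled_dot_product_attention expects (batch, nheads, seqlen, headdim)
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    )
    return out.transpose(1, 2)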
requirements.txt CHANGED
@@ -1,4 +1,3 @@
-torch==2.7.1
 tqdm
 librosa==0.10.2.post1
 peft==0.15.1
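Dropping the torch pin, together with the new 128/ wheel path, reads as deferring to the PyTorch build already present in the Space's CUDA 12.8 image. An illustrative sanity check, not part of the commit, to confirm the runtime matches the cu128 wheel:

import torch

# Standard torch attributes; the values in the comments are what one would expect,
# not verified output from this Space.
print(torch.__version__)          # e.g. a build tagged +cu128
print(torch.version.cuda)         # should report 12.8
print(torch.cuda.is_available())  # the wheel only matters when a CUDA GPU is present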