srinuksv committed on
Commit
cc17f44
·
verified ·
1 Parent(s): bc163bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -16,23 +16,35 @@ from qwen_omni_utils import process_mm_info
16
  from argparse import ArgumentParser
17
 
18
def _load_model_processor(args):
    """Load the Qwen2.5-Omni model and its processor from the checkpoint.

    Honors ``--cpu-only`` (forces CPU placement) and ``--flash-attn2``
    (selects the flash_attention_2 implementation with auto dtype).

    Returns:
        tuple: ``(model, processor)``.
    """
    device_map = 'cpu' if args.cpu_only else 'auto'

    # Assemble from_pretrained kwargs; the flash-attn2 path additionally
    # pins the attention implementation and lets HF choose the dtype.
    load_kwargs = {'device_map': device_map}
    if args.flash_attn2:
        load_kwargs['torch_dtype'] = 'auto'
        load_kwargs['attn_implementation'] = 'flash_attention_2'
    model = Qwen2_5OmniModel.from_pretrained(args.checkpoint_path, **load_kwargs)

    processor = Qwen2_5OmniProcessor.from_pretrained(args.checkpoint_path)
    return model, processor
35
 
 
36
  def _launch_demo(args, model, processor):
37
  # Voice settings
38
  VOICE_LIST = ['Chelsie', 'Ethan']
 
16
  from argparse import ArgumentParser
17
 
18
def _load_model_processor(args):
    """Load the Qwen2.5-Omni model and processor from ``args.checkpoint_path``.

    ``--cpu-only`` forces CPU placement; ``--flash-attn2`` switches to the
    flash_attention_2 implementation. A ``max_memory`` map caps per-device
    memory so accelerate's dispatcher does not over-allocate.

    Returns:
        tuple: ``(model, processor)``.
    """
    import torch

    if args.cpu_only:
        device_map = 'cpu'
        # accelerate keys CPU memory as "cpu"; an integer key (0) would
        # refer to GPU index 0, which is wrong when running CPU-only.
        max_memory = {"cpu": "2GB"}  # Limit memory usage when running on CPU
    else:
        device_map = 'auto'
        # One entry per visible GPU; adjust the per-GPU cap as needed.
        max_memory = {i: "20GB" for i in range(torch.cuda.device_count())}
        if not max_memory:
            # No CUDA devices visible: give the dispatcher a CPU budget
            # instead of an empty map.
            max_memory = {"cpu": "20GB"}

    # Check if flash-attn2 flag is enabled and load model accordingly
    if args.flash_attn2:
        model = Qwen2_5OmniModel.from_pretrained(
            args.checkpoint_path,
            torch_dtype='auto',
            attn_implementation='flash_attention_2',
            device_map=device_map,
            max_memory=max_memory
        )
    else:
        model = Qwen2_5OmniModel.from_pretrained(
            args.checkpoint_path,
            device_map=device_map,
            max_memory=max_memory
        )

    processor = Qwen2_5OmniProcessor.from_pretrained(args.checkpoint_path)
    return model, processor
46
 
47
+
48
  def _launch_demo(args, model, processor):
49
  # Voice settings
50
  VOICE_LIST = ['Chelsie', 'Ethan']