qingpei committed on
Commit 9a10e16 · verified · 1 Parent(s): cd61489

Update test_infer.py

Files changed (1): test_infer.py (+5 -3)
test_infer.py CHANGED
@@ -15,7 +15,8 @@ if __name__ == '__main__':
         attn_implementation="flash_attention_2",
     ).to("cuda")
 
-    vision_path = "/input/zhangqinglong.zql/assets/"
+    # replace with your asset path
+    vision_path = os.environ.get("VISION_PATH", "") or os.path.join(os.path.dirname(__file__), "vision")
 
     # qa
     # messages = [
@@ -71,12 +72,13 @@ if __name__ == '__main__':
     # },
     # ]
 
+    # note: place the audio file (audio.wav) in the vision_path directory
     messages = [
         {
             "role": "HUMAN",
             "content": [
                 {"type": "text", "text": "Please recognize the language of this speech and transcribe it. Format: oral."},
-                {"type": "audio", "audio": '/input/dongli.xq/BAC009S0915W0292.wav'},
+                {"type": "audio", "audio": os.path.join(vision_path, "audio.wav")},
             ],
         },
     ]
@@ -95,7 +97,7 @@ if __name__ == '__main__':
     for k in inputs.keys():
         if k == "pixel_values" or k == "pixel_values_videos" or k == "audio_feats":
             inputs[k] = inputs[k].to(dtype=torch.bfloat16)
-
+
     generated_ids = model.generate(
         **inputs,
         max_new_tokens=128,
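
For reference, a minimal sketch of the new path resolution (the VISION_PATH environment variable and the audio.wav filename come from the diff above; the fallback "vision" directory name is likewise taken from the committed code):

import os

# Prefer an explicit VISION_PATH environment variable; otherwise fall back
# to a "vision" directory next to this script. os.environ.get returns ""
# when the variable is unset, which is falsy, so the fallback kicks in.
vision_path = os.environ.get("VISION_PATH", "") or os.path.join(
    os.path.dirname(__file__), "vision"
)

# The audio message then reads <vision_path>/audio.wav, so the file must be
# placed there before running the script, e.g.:
#   VISION_PATH=/data/assets python test_infer.py
audio_file = os.path.join(vision_path, "audio.wav")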