Audio-Text-to-Text · Transformers · Safetensors · qwen2_audio · text2text-generation · Inference Endpoints
jimbozhang committed · verified
Commit b74b317 · 1 Parent(s): 26598db

Update README.md

Files changed (1):
  1. README.md +15 -16
README.md CHANGED

@@ -20,35 +20,34 @@ import torch
 import torchaudio
 from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
 
-
-def _get_audio(wav_path):
-    waveform, sample_rate = torchaudio.load(wav_path)
-    if sample_rate != 16000:
-        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-    audio = waveform[0]
-    return audio
-
+# Load model
 model_name = "mispeech/r1-aqa"
-audio_url = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # Copied from the MMAU dataset
-
 processor = AutoProcessor.from_pretrained(model_name)
 model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
 
+# Load example audio
+wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from the MMAU dataset
+waveform, _ = torchaudio.load(wav_path)  # already 16 kHz
+audios = [waveform[0].numpy()]
+
+# Make prompt
+question = "Based on the given audio, identify the source of the speaking voice."
+options = ["Man", "Woman", "Child", "Robot"]
+prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
 message = [
     {"role": "user", "content": [
-        {"type": "audio", "audio_url": audio_url},
-        {"type": "text", "text": "Based on the given audio, identify the source of the speaking voice. Please choose the answer from the following options: ['Man', 'Woman', 'Child', 'Robot']. Output the final answer in <answer> </answer>."}
+        {"type": "audio", "audio_url": wav_path},
+        {"type": "text", "text": prompt}
     ]}
 ]
 
-audios = [_get_audio(audio_url).numpy()]
+# Process
 texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
-
 inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
-
 generated_ids = model.generate(**inputs, max_new_tokens=256)
 generated_ids = generated_ids[:, inputs.input_ids.size(1):]
 response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-print(f"response:{response}")
+
+print(response)
 ```
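The updated snippet loads the example wav directly and assumes it is already sampled at the 16 kHz the processor expects; the removed `_get_audio` helper handled other sample rates. For arbitrary input audio, a resampling step along the same lines still applies. A minimal sketch, assuming only `torchaudio` and a placeholder path `your_audio.wav`:

```python
import torchaudio

# Load a wav and resample to 16 kHz if needed, mirroring the removed helper.
# "your_audio.wav" is a placeholder, not a file shipped with the model.
waveform, sample_rate = torchaudio.load("your_audio.wav")
if sample_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]  # first channel as a 1-D array, as in the snippet above
```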
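Since the prompt asks the model to wrap its final choice in `<answer> </answer>` tags, the decoded `response` can be reduced to the bare option with a small regex. A sketch, assuming the model follows the requested format (falling back to the raw text otherwise):

```python
import re

# response is the list returned by processor.batch_decode above
match = re.search(r"<answer>\s*(.*?)\s*</answer>", response[0], re.DOTALL)
answer = match.group(1) if match else response[0].strip()
print(answer)  # one of "Man", "Woman", "Child" or "Robot"
```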