|
--- |
|
license: apache-2.0 |
|
library_name: transformers |
|
tags: [] |
|
--- |
|
|
|
# R1-AQA |
|
|
|
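R1-AQA is an audio question-answering (AQA) model built on the Qwen2-Audio architecture; it is loaded with `Qwen2AudioForConditionalGeneration` in 🤗 Transformers. The example below asks a multiple-choice question about an audio clip and requests the final answer inside `<answer> </answer>` tags.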
|
|
|
|
|
|
|
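## Inference

The example below requires `torch`, `torchaudio`, and `transformers`. It reads a local audio file, so replace the example `.wav` path with a file of your own before running it.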
|
```python |
|
import torch |
|
import torchaudio |
|
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor |
|
|
|
|
|
def _get_audio(wav_path, target_sr=16000):
    """Load an audio file as a 1-D waveform at ``target_sr`` Hz.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.
        target_sr: Sample rate to resample to, in Hz. The default of 16000
            matches the ``sampling_rate=16000`` passed to the processor in
            this example; keep the two in sync if you change either.

    Returns:
        A 1-D tensor containing only the first channel of the file,
        resampled to ``target_sr`` if the file's native rate differs.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    # Keep only the first channel (same as before). Selecting it *before*
    # resampling avoids resampling channels that would be discarded anyway;
    # resampling is per-channel, so the result is unchanged.
    audio = waveform[0]
    if sample_rate != target_sr:
        # Stateless functional form: no Resample module is built per call.
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=target_sr
        )
    return audio
|
|
|
model_id = "mispeech/r1-aqa"
# Local example clip — presumably not distributed with the model; point this
# at an audio file of your own.
audio_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"

processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Multiple-choice question; the prompt asks for the answer inside <answer> tags.
question = (
    "Based on the given audio, identify the source of the speaking voice. "
    "Please choose the answer from the following options: ['Man', 'Woman', 'Child', 'Robot']. "
    "Output the final answer in <answer> </answer>."
)
audio_part = {"type": "audio", "audio_url": audio_path}
text_part = {"type": "text", "text": question}
message = [{"role": "user", "content": [audio_part, text_part]}]

# Raw waveforms go to the processor as NumPy arrays, one per audio entry.
audios = [_get_audio(audio_path).numpy()]
prompt = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# sampling_rate must match the rate _get_audio resamples to (16 kHz).
encoded = processor(
    text=prompt,
    audios=audios,
    sampling_rate=16000,
    return_tensors="pt",
    padding=True,
)
inputs = encoded.to(model.device)

output_ids = model.generate(**inputs, max_new_tokens=256)
# The generated sequence starts with the prompt; keep only the new tokens.
prompt_len = inputs.input_ids.size(1)
new_ids = output_ids[:, prompt_len:]
response = processor.batch_decode(
    new_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
print(f"response:{response}")
|
``` |
|
|
|
|