---
license: apache-2.0
library_name: transformers
tags: []
---

# R1-AQA 

An audio question answering (AQA) model built on `Qwen2-Audio-7B-Instruct` and trained with group relative policy optimization (GRPO).

## Introduction

R1-AQA extends `Qwen2-Audio-7B-Instruct` by integrating group relative policy optimization (GRPO). This adaptation enhances the model's capacity for temporal reasoning and contextual alignment in audio question answering (AQA) tasks.  
For more details, please refer to our [GitHub](https://github.com/xiaomi/r1-aqa) and [Report]().
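
GRPO replaces a learned value function with a group-relative baseline: for each question, several responses are sampled and each response's reward is normalized against the group's mean and standard deviation. As a minimal sketch of that advantage computation (illustrative only; the function name, tensor shapes, and `eps` are assumptions, not code from this repository):

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """GRPO-style advantages: normalize each sampled response's reward
    against the mean/std of its own sampling group.

    rewards: (num_prompts, group_size) -- one row per question,
             one column per response sampled for that question.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)
```

These advantages then weight the policy-gradient objective, so no separate critic network is required.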


## Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Load example audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # an example from the MMAU dataset
waveform, _ = torchaudio.load(wav_path)  # MMAU audio is already 16 kHz; resample first if yours is not
audios = [waveform[0].numpy()]

# Make prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt}
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Process
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]  # drop the prompt tokens, keep only the generated answer
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(response)
```
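
The prompt asks the model to wrap its final choice in `<answer>` tags, so the decoded string can be post-processed with a small regex. The `extract_answer` helper below is a hypothetical addition, not part of the released code:

```python
import re

def extract_answer(text: str) -> str | None:
    """Hypothetical helper: pull the option out of <answer> ... </answer> tags."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return match.group(1) if match else None

print(extract_answer(response[0]))  # e.g. one of the listed options
```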