---
license: apache-2.0
library_name: transformers
tags: []
---
# R1-AQA
## Introduction
R1-AQA extends `Qwen2-Audio-7B-Instruct` by integrating group relative policy optimization (GRPO). This adaptation enhances the model's capacity for temporal reasoning and contextual alignment in audio question answering (AQA) tasks.
For more details, please refer to our [GitHub](https://github.com/xiaomi/r1-aqa) and [Report]().
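At a high level, GRPO samples a group of candidate answers per question, rewards each one (e.g. 1 if the choice is correct, else 0), and normalizes rewards within the group to form advantages, avoiding a separate value model. A minimal, illustrative sketch of that advantage computation (not the training code from this repo):

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: (num_prompts, group_size), one row of sampled-answer rewards per prompt.
    # Each reward is normalized against its own group's mean and std.
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# Example: 4 sampled answers to one AQA prompt, two of them correct.
print(grpo_advantages(torch.tensor([[1.0, 0.0, 0.0, 1.0]])))
```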
## Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Load example audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from the MMAU dataset
waveform, sr = torchaudio.load(wav_path)  # MMAU audio is 16 kHz
if sr != 16000:  # resample defensively if you substitute your own audio
    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
audios = [waveform[0].numpy()]

# Make prompt
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
{"role": "user", "content": [
{"type": "audio", "audio_url": wav_path},
{"type": "text", "text": prompt}
]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Process
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response)
```
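The prompt instructs the model to wrap its final choice in `<answer>` tags, so the decoded response can be parsed with a small helper like this (a hypothetical convenience, not part of the repo):

```python
import re

def extract_answer(text: str):
    """Return the text inside the first <answer> ... </answer> pair, or None."""
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    return m.group(1) if m else None

print(extract_answer(response[0]))  # the chosen option, or None if no tags were emitted
```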