---
license: apache-2.0
library_name: transformers
tags: []
---
# R1-AQA
## Introduction
R1-AQA extends `Qwen2-Audio-7B-Instruct` with group relative policy optimization (GRPO). This adaptation improves the model's temporal reasoning and contextual alignment on audio question answering (AQA) tasks.
For more details, please refer to our [GitHub repository](https://github.com/xiaomi/r1-aqa) and [Report]().
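At its core, GRPO replaces PPO's learned critic with a group-relative baseline: several answers are sampled per prompt and each reward is normalized against its own sampling group. The sketch below is illustrative only, not the authors' training code; the reward values and the binary `<answer>`-matching reward it mentions are assumptions for the example.

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalize each reward against its sampling group (the GRPO advantage)."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Hypothetical example: four answers sampled for one AQA prompt, scored 1.0
# when the text inside <answer></answer> matches the reference, else 0.0.
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(group_relative_advantages(rewards))  # positive for correct answers, negative otherwise
```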
## Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
# Load example audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from the MMAU dataset
waveform, _ = torchaudio.load(wav_path)  # 16 kHz; resample with torchaudio.functional.resample if needed
audios = [waveform[0].numpy()]  # first channel as a 1-D NumPy array
# Make prompt
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
{"role": "user", "content": [
{"type": "audio", "audio_url": wav_path},
{"type": "text", "text": prompt}
]}
]
# Process: build the text prompt with the chat template, then encode text and audio together
text = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
inputs = processor(text=text, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]  # drop the prompt tokens, keep only the generated ones
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response)
```
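Since the prompt asks the model to wrap its choice in `<answer> </answer>` tags, the final answer can be pulled out of the decoded text. A minimal sketch, assuming the model follows the requested format:

```python
import re

# `response` holds one decoded string per batch item.
match = re.search(r"<answer>\s*(.*?)\s*</answer>", response[0], re.DOTALL)
answer = match.group(1) if match else response[0].strip()  # fall back to the raw text
print(answer)
```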