license: apache-2.0
library_name: transformers
tags: []

R1-AQA

Introduction

R1-AQA extends Qwen2-Audio-7B-Instruct by training it with group relative policy optimization (GRPO), a reinforcement learning algorithm. This adaptation enhances the model's capacity for temporal reasoning and contextual alignment in audio question answering (AQA) tasks.
For more details, please refer to our GitHub repository and technical report.
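As a rough illustration of the "group relative" part of GRPO, the sketch below normalizes the rewards of several sampled responses to one question against the group's own mean and standard deviation. It is not the training code behind R1-AQA: the function name, the reward values, and the use of the population standard deviation are assumptions made for this example, and implementations differ in such details.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Score each response relative to the other responses in its group.

    `rewards` holds one scalar reward per response sampled for the same question.
    `eps` (an assumed small constant) avoids division by zero when all rewards match.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for four sampled answers: 1.0 = correct, 0.0 = incorrect.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
print(advantages)
```

Responses that score above the group average receive a positive advantage and are reinforced, while below-average responses receive a negative one. If every response in a group earns the same reward, all advantages are zero and that group contributes no learning signal.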

Inference

import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

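# Load example audio (16 kHz mono input is used below)
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)  # waveform shape: (channels, samples)
if sampling_rate != 16000:
    # Resample instead of failing on audio recorded at another rate
    waveform = torchaudio.functional.resample(waveform, sampling_rate, 16000)
audios = [waveform[0].numpy()]  # first channel as a 1-D array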

# Make prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt}
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Tokenize the prompt and extract audio features
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
# Drop the prompt tokens so that only the newly generated text is decoded
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(response)
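
Because the prompt asks for the final answer inside <answer> </answer> tags, the chosen option can be pulled out of the decoded text with a small helper. This is a minimal sketch: `extract_answer` and the sample string are hypothetical and not part of the released code, and the model is not guaranteed to emit the tags every time, so the helper returns None when they are missing.

```python
import re


def extract_answer(text):
    """Return the text between the first <answer>...</answer> pair, or None if absent."""
    match = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else None


# Hypothetical decoded output, shaped like the list returned by batch_decode.
example_response = ["The voice is low-pitched and resonant. <answer>Man</answer>"]
print(extract_answer(example_response[0]))
```

With real model output, the same call would be `extract_answer(response[0])`, since `batch_decode` returns one string per batch item.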