---
license: apache-2.0
library_name: transformers
tags: []
---
# R1-AQA
## Introduction
R1-AQA is based on Qwen2-Audio-7B-Instruct, with the group relative policy optimization (GRPO) algorithm applied to the Audio Question Answering (AQA) task.

For more details, please refer to our GitHub and Report.
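As a rough sketch of the idea (the reward design actually used for AQA is documented in the report), GRPO samples a group of $G$ answers per question, scores each with a reward $r_i$, and uses the group-normalized reward as the advantage, so no separate value network is needed:

```latex
% Group-relative advantage in GRPO (general form; not specific to R1-AQA's reward):
A_i = \frac{r_i - \operatorname{mean}(\{r_1, \dots, r_G\})}{\operatorname{std}(\{r_1, \dots, r_G\})}
```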
## Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor


def _get_audio(wav_path):
    # Load the waveform and resample to the 16 kHz rate expected by Qwen2-Audio.
    waveform, sample_rate = torchaudio.load(wav_path)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    audio = waveform[0]  # Use the first channel.
    return audio


model_name = "mispeech/r1-aqa"
audio_url = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # Copied from the MMAU dataset.

processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": audio_url},
        {"type": "text", "text": "Based on the given audio, identify the source of the speaking voice. Please choose the answer from the following options: ['Man', 'Woman', 'Child', 'Robot']. Output the final answer in <answer> </answer>."}
    ]}
]

audios = [_get_audio(audio_url).numpy()]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]  # Keep only the newly generated tokens.
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(f"response: {response}")
```