license: apache-2.0
library_name: transformers
tags: []

R1-AQA

Introduction

R1-AQA is based on Qwen2-Audio-7B-Instruct, with the group relative policy optimization (GRPO) algorithm applied to the audio question answering (AQA) task.
For more details, please refer to our GitHub and report.
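As background, GRPO samples a group of candidate answers per question, scores each with a rule-based reward, and normalizes the rewards within the group to obtain advantages. The snippet below is only an illustrative sketch of such a reward for the multiple-choice AQA setting; the function name, the toy completions, and the exact reward design are assumptions made here for illustration, not the training recipe, which is described in the report.

import re
import numpy as np

def accuracy_reward(completion: str, ground_truth: str) -> float:
    # Illustrative rule-based reward: 1.0 if the option inside <answer> ... </answer>
    # matches the ground-truth choice, else 0.0.
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", completion, re.DOTALL)
    predicted = match.group(1).strip() if match else ""
    return 1.0 if predicted.lower() == ground_truth.lower() else 0.0

# Toy group of sampled completions for one question (hypothetical).
sampled_completions = [
    "The speaker sounds female. <answer>Woman</answer>",
    "<answer>Man</answer>",
]

# Group-relative advantages: each reward is normalized against the group statistics.
rewards = np.array([accuracy_reward(c, "Woman") for c in sampled_completions])
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
print(advantages)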

Inference

import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor


def _get_audio(wav_path):
    # Load the waveform and resample to the 16 kHz rate expected by Qwen2-Audio.
    waveform, sample_rate = torchaudio.load(wav_path)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Use the first channel as a 1-D tensor.
    audio = waveform[0]
    return audio

model_name = "mispeech/r1-aqa"
audio_url = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # Copied from the MMAU dataset

processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": audio_url},
        {"type": "text", "text": "Based on the given audio, identify the source of the speaking voice. Please choose the answer from the following options: ['Man', 'Woman', 'Child', 'Robot']. Output the final answer in <answer> </answer>."}
    ]}
]

audios = [_get_audio(audio_url).numpy()]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(f"response:{response}")