---
license: apache-2.0
library_name: transformers
tags: []
---

# R1-AQA

## Introduction

R1-AQA extends `Qwen2-Audio-7B-Instruct` by integrating group relative policy optimization (GRPO). This adaptation enhances the model's capacity for temporal reasoning and contextual alignment in audio question answering (AQA) tasks.

For more details, please refer to our [GitHub](https://github.com/xiaomi/r1-aqa) and [Report]().

## Inference

```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load the model and processor
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

# Load an example audio clip; the processor expects 16 kHz audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from the MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)
assert sampling_rate == 16000
audios = [waveform[0].numpy()]  # first channel as a 1-D array

# Build the prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": wav_path},
            {"type": "text", "text": prompt},
        ],
    }
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Tokenize the text, extract audio features, and generate
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]  # keep only the newly generated tokens
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response)
```
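
Because the prompt instructs the model to place its final choice inside `<answer>` tags, you will usually want to strip the surrounding reasoning from the decoded response. Below is a minimal sketch, assuming the tags appear verbatim in the output; the `extract_answer` helper is illustrative and not part of the released code:

```python
import re

def extract_answer(text: str) -> str:
    """Return the content of the last <answer>...</answer> span, or the raw text as a fallback."""
    matches = re.findall(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return matches[-1].strip() if matches else text.strip()

print(extract_answer(response[0]))
```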