import gradio as gr
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio
import torchaudio.transforms as T
import logging

import json

import importlib 
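# The model code ships in the local "MERT-v0-public" directory; the hyphenated
# name is not a valid Python identifier, so the module is loaded dynamically.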
modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")

from Prediction_Head.MTGGenre_head import MLPProberBase 
# Input components adapted from: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py


logger = logging.getLogger("MERT-music-tagging-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)



inputs = [
    gr.components.Audio(type="filepath", label="Add music audio file"),
    gr.components.Audio(source="microphone", type="filepath", label="Or record with the microphone"),
]
live_inputs = [
    gr.Audio(source="microphone", streaming=True, type="filepath"),
]
title = "Predict the top 5 possible genres and tags of music"
description = "An example of using the m-a-p/MERT-v0-public model as a backbone for music genre/tagging prediction."
article = ""
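# Paths to bundled example clips shown under the inputs; left empty here.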
audio_examples = [
    # ["input/example-1.wav"],
    # ["input/example-2.wav"],
]

# Load the model and the corresponding preprocessor config from the local
# MERT-v0-public checkout. The equivalent Hugging Face Hub calls would be:
# model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
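# The Wav2Vec2 feature extractor exposes the sampling rate the backbone was
# trained on (processor.sampling_rate); incoming audio is resampled to it below.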

# Index of the MERT hidden-state layer fed to the genre head. Each layer
# performs differently on different downstream tasks, so it is chosen empirically.
MERT_LAYER_IDX = 7
MTGGenre_classifier = MLPProberBase()
# map_location keeps the checkpoint loadable on CPU-only machines.
MTGGenre_classifier.load_state_dict(torch.load('Prediction_Head/best_MTGGenre.ckpt', map_location='cpu')['state_dict'])

with open('Prediction_Head/MTGGenre_id2class.json', 'r') as f:
    id2cls = json.load(f)


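# Pick the GPU when available and keep the backbone and prediction head on the
# same device, so pooled hidden states can go straight into the classifier.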
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
MTGGenre_classifier.to(device)

def convert_audio(inputs, microphone):
    # A microphone recording takes precedence over an uploaded file.
    if microphone is not None:
        inputs = microphone

    waveform, sample_rate = torchaudio.load(inputs)

    resample_rate = processor.sampling_rate

    # Resample if the file's rate does not match the rate the model expects.
    if resample_rate != sample_rate:
        print(f'setting rate from {sample_rate} to {resample_rate}')
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)

    # Downmix to mono so stereo files become a single (n_samples,) waveform.
    waveform = waveform.mean(dim=0)
    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)

    # 13 layers of representation are returned (input embedding + 12 transformer
    # layers); MERT_LAYER_IDX selects the one used for tagging.
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
    print(all_layer_hidden_states.shape)  # [13 layers, time steps, 768 feature_dim]

    # Mean-pool the chosen layer over time and feed it to the genre head.
    logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0))  # [87] genre tags
    print(logits.shape)
    sorted_idx = torch.argsort(logits, dim=-1, descending=True)

    output_texts = "\n".join([id2cls[str(idx.item())].replace('genre---', '') for idx in sorted_idx[:5]])

    return f"device: {device}\n" + output_texts

def live_convert_audio(microphone):
    # Streaming mode only receives microphone chunks; reuse the same
    # prediction pipeline as the file-based interface.
    if microphone is None:
        return ""
    return convert_audio(microphone, None)


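# Two interfaces share the same prediction pipeline: one for uploaded files or
# one-shot recordings, and a live variant that re-runs on streamed microphone audio.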
audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=[gr.components.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)

live_audio_chunked = gr.Interface(
    fn=live_convert_audio,
    inputs=live_inputs,
    outputs=[gr.components.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    # examples=audio_examples,
    live=True,
)


demo = gr.Blocks()
with demo:
    gr.TabbedInterface(
        [
            audio_chunked,
            live_audio_chunked,
        ], 
        [
            "Audio File or Recording",
            "Live Streaming Music"
        ]
    )
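# A single worker handles requests, with at most five waiting in the queue.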
demo.queue(concurrency_count=1, max_size=5)
demo.launch(show_api=False)