import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Load a small Wav2Vec2 CTC model and its processor.
# NOTE: "facebook/wav2vec2-base-960h" is an English ASR checkpoint (the plain
# "facebook/wav2vec2-base" has no fine-tuned CTC head or tokenizer); for real
# Persian transcription, swap in a wav2vec2/XLSR checkpoint fine-tuned on Persian speech.
model_name = "facebook/wav2vec2-base-960h"  # small model that fits modest Space hardware
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
def transcribe_audio(audio):
    # `audio` is a file path (gr.Audio with type="filepath"); guard against empty input.
    if audio is None:
        return ""
    # Load the audio file, downmix to mono, and resample to the 16 kHz rate the model expects.
    waveform, sample_rate = torchaudio.load(audio)
    waveform = waveform.mean(dim=0)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Preprocess the audio
    input_values = processor(waveform.numpy(), return_tensors="pt", sampling_rate=16000).input_values
    # Perform inference
    with torch.no_grad():
        logits = model(input_values).logits
    # Decode the logits to text (greedy CTC decoding)
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription
with gr.Blocks(fill_height=True) as demo:
    with gr.Sidebar():
        gr.Markdown("# Inference Provider")
        gr.Markdown("This Space showcases the google/gemma-2-2b-it model, served by the nebius API. Sign in with your Hugging Face account to use this API.")
        button = gr.LoginButton("Sign in")
    with gr.Tab("Text Inference"):
        gr.load("models/google/gemma-2-2b-it", accept_token=button, provider="nebius")
    with gr.Tab("Persian ASR"):
        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        text_output = gr.Textbox(label="Transcription")
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(transcribe_audio, inputs=audio_input, outputs=text_output)

demo.launch()
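For a quick check of the ASR path outside the Space, the transcription function can be called directly. The sketch below is illustrative only: it assumes the definitions above have been run in a local session without the blocking demo.launch() call, with the same dependencies installed (gradio, torch, torchaudio, transformers), and "sample.wav" is a hypothetical placeholder for a real audio file on disk.

# Quick local smoke test: call the transcription function directly
# (run the definitions above first, skipping demo.launch()).
# "sample.wav" is a placeholder path; replace it with a real recording.
print(transcribe_audio("sample.wav"))  # prints the greedy CTC transcription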