"""Streamlit app: upload a WAV file and show the class predicted by a CED model."""

import os
import tempfile

import soundfile as sf
import streamlit as st
import torch
import torchaudio

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification

model_name = "mispeech/ced-tiny"

# Sample rate the audio is resampled to before feature extraction.
TARGET_SAMPLE_RATE = 16000


@st.cache_resource
def _load_model(name):
    """Load the feature extractor and classifier for *name* once per process.

    Streamlit re-executes the whole script on every interaction; caching keeps
    the two ``from_pretrained`` calls from running again on each rerun.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    return (
        CedFeatureExtractor.from_pretrained(name),
        CedForAudioClassification.from_pretrained(name),
    )


def _load_audio(path):
    """Read the audio file at *path* as a float32 ``(channels, frames)`` tensor.

    ``torchaudio`` is tried first; if it raises, ``soundfile`` is used instead
    and its output is converted to the same layout and dtype.

    Returns:
        A ``(waveform, sampling_rate)`` tuple.
    """
    try:
        return torchaudio.load(path)
    except Exception:  # deliberately broad: any torchaudio failure triggers the fallback
        st.warning("Fallback to soundfile for audio loading.")
        # always_2d gives (frames, channels) even for mono, so one transpose
        # yields the (channels, frames) layout for every input; float32 keeps
        # the dtype identical to the torchaudio path.
        samples, rate = sf.read(path, dtype="float32", always_2d=True)
        return torch.from_numpy(samples).T.contiguous(), rate


def _predict_label(audio, sampling_rate):
    """Return the label the model predicts for *audio*.

    Args:
        audio: Waveform tensor in ``(channels, frames)`` layout.
        sampling_rate: Sample rate of *audio* in Hz.
    """
    # Collapse multi-channel audio to one channel: `.item()` below needs a
    # single prediction. NOTE(review): assumes the extractor would otherwise
    # treat each channel as a separate batch item -- confirm.
    if audio.dim() > 1 and audio.size(0) > 1:
        audio = audio.mean(dim=0, keepdim=True)

    if sampling_rate != TARGET_SAMPLE_RATE:
        st.warning("Resampling audio to 16000 Hz...")
        resampler = torchaudio.transforms.Resample(
            orig_freq=sampling_rate, new_freq=TARGET_SAMPLE_RATE
        )
        audio = resampler(audio)
        sampling_rate = TARGET_SAMPLE_RATE

    inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]


feature_extractor, model = _load_model(model_name)

st.title("Audio Classification App")
st.subheader("Trained on 50 classes of ESC 50 dataset")
st.write("Upload an audio file to predict its class.")

audio_file = st.file_uploader("Upload Audio File", type=["wav"])

if audio_file is not None:
    st.write(f"Uploaded file: {audio_file.name}")

    temp_file_path = None
    try:
        # A unique temp file per request avoids concurrent sessions clobbering
        # a shared "temp.wav". delete=False so the loaders can reopen the path
        # after this handle is closed; removal happens in `finally`.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_file_path = tmp.name
            # getvalue() returns the full upload regardless of stream position.
            tmp.write(audio_file.getvalue())

        audio, sampling_rate = _load_audio(temp_file_path)
        predicted_label = _predict_label(audio, sampling_rate)
        st.success(f"Predicted Class: {predicted_label}")
    except Exception as e:  # top-level UI boundary: report instead of crashing
        st.error(f"An error occurred: {e}")
    finally:
        # Remove the temp file on both success and failure.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
else:
    st.info("Please upload a .wav audio file to continue.")