Spaces:
Paused
Paused
import streamlit as st | |
from flask.Emotion_spotting_service import _Emotion_spotting_service | |
from flask.Genre_spotting_service import _Genre_spotting_service | |
from flask.Beat_tracking_service import _Beat_tracking_service | |
from diffusers import StableDiffusionPipeline | |
import tensorflow as tf | |
import torch | |
import os | |
physical_devices = tf.config.experimental.list_physical_devices('GPU') | |
if len(physical_devices) > 0: | |
tf.config.experimental.set_memory_growth(physical_devices[0], True) | |
def load_emo_model(): | |
emo_service = _Emotion_spotting_service("flask/emotion_model.h5") | |
return emo_service | |
def load_genre_model(): | |
gen_service = _Genre_spotting_service("flask/Genre_classifier_model.h5") | |
return gen_service | |
def load_beat_model(): | |
beat_service = _Beat_tracking_service() | |
return beat_service | |
def load_image_model(): | |
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda") | |
pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors") | |
return pipeline | |
if 'emotion' not in st.session_state: | |
st.session_state.emotion = None | |
if 'genre' not in st.session_state: | |
st.session_state.genre = None | |
if 'beat' not in st.session_state: | |
st.session_state.beat = None | |
if "text_prompt" not in st.session_state: | |
st.session_state.text_prompt = None | |
emotion_service = load_emo_model() | |
genre_service = load_genre_model() | |
beat_service = load_beat_model() | |
image_service = load_image_model() | |
st.title("Music2Image webpage") | |
user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"],key = "file_uploader") | |
st.caption("Generate images from your audio file") | |
st.audio(user_input) | |
c1,c2,c3 = st.columns([1,1,1]) | |
with c1: | |
if st.button("Generate emotion"): | |
emotion = emotion_service.predict(user_input) | |
st.session_state.emotion = emotion | |
st.text(st.session_state.emotion) | |
with c2: | |
if st.button("Generate genre"): | |
genre = genre_service.predict(user_input) | |
st.session_state.genre = genre | |
st.text(st.session_state.genre) | |
with c3: | |
if st.button("Generate beat"): | |
beat = beat_service.get_beat(user_input) | |
st.session_state.beat = beat | |
st.text(st.session_state.beat) | |
if st.session_state.emotion != None and st.session_state.genre != None and st.session_state.beat != None: | |
if st.button("Generate text description to be fed into stable diffusion"): | |
if st.session_state.beat > 100: | |
speed = "medium and steady" | |
else: | |
speed = "slow and calm" | |
st.caption("Text description of your music file") | |
text = "A scenic image that describes a " + speed + " pace with a feeling of" + st.session_state.emotion + "." | |
st.session_state.text_prompt = text | |
st.text(st.session_state.text_prompt) | |
if st.session_state.text_prompt: | |
if st.button("Generate image from text description"): | |
image = image_service(st.session_state.text_prompt).images[0] | |
st.image(image) | |