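# Streamlit front end for Music2Image: upload a wav/mp3 file, run emotion,
# genre and beat analysis on it, then turn the results into a text prompt
# for a LoRA-tuned Stable Diffusion model.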
import streamlit as st
from flask.Emotion_spotting_service import _Emotion_spotting_service
from flask.Genre_spotting_service import _Genre_spotting_service
from flask.Beat_tracking_service import _Beat_tracking_service
from diffusers import StableDiffusionPipeline
import tensorflow as tf
import torch

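# Let TensorFlow allocate GPU memory on demand instead of grabbing it all up
# front, so the PyTorch-based Stable Diffusion pipeline can share the GPU.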
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

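# Cached loaders: st.cache_resource keeps each model in memory across Streamlit reruns.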
@st.cache_resource
def load_emo_model():
    emo_service = _Emotion_spotting_service("flask/emotion_model.h5")
    return emo_service

@st.cache_resource
def load_genre_model():
    gen_service = _Genre_spotting_service("flask/Genre_classifier_model.h5")
    return gen_service

@st.cache_resource
def load_beat_model():
    beat_service = _Beat_tracking_service()
    return beat_service

@st.cache_resource
def load_image_model():
    # Assumes a CUDA-capable GPU; fp16 weights roughly halve the VRAM needed.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    # Load the fine-tuned LoRA weights from the local Weights/ directory.
    pipeline.load_lora_weights("Weights", weight_name="pytorch_lora_weights.safetensors")
    return pipeline


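# Analysis results live in session state so they survive Streamlit's reruns.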
if 'emotion' not in st.session_state:
    st.session_state.emotion = None

if 'genre' not in st.session_state:
    st.session_state.genre = None

if 'beat' not in st.session_state:
    st.session_state.beat = None

if "text_prompt" not in st.session_state:
    st.session_state.text_prompt = None

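# Instantiate the services; after the first run these all hit the cache.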
emotion_service = load_emo_model()
genre_service = load_genre_model()
beat_service = load_beat_model()
image_service = load_image_model()

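# Page layout: uploader on top, then three columns with one button per analysis.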
st.title("Music2Image webpage")
user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav", "mp3"], key="file_uploader")
st.caption("Generate images from your audio file")
if user_input is not None:
    st.audio(user_input)

c1, c2, c3 = st.columns([1, 1, 1])
with c1:
    if st.button("Generate emotion") and user_input is not None:
        st.session_state.emotion = emotion_service.predict(user_input)
    st.text(st.session_state.emotion)
with c2:
    if st.button("Generate genre") and user_input is not None:
        st.session_state.genre = genre_service.predict(user_input)
    st.text(st.session_state.genre)
with c3:
    if st.button("Generate beat") and user_input is not None:
        st.session_state.beat = beat_service.get_beat(user_input)
    st.text(st.session_state.beat)

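# Once all three analyses have run, build a text prompt from tempo and emotion.
# The beat value is assumed to be a tempo estimate in BPM.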
if st.session_state.emotion is not None and st.session_state.genre is not None and st.session_state.beat is not None:
    if st.button("Generate text description to be fed into stable diffusion"):
        if st.session_state.beat > 100:
            speed = "medium and steady"
        else:
            speed = "slow and calm"
        st.caption("Text description of your music file")
        text = "A scenic image that describes a " + speed + " pace with a feeling of " + st.session_state.emotion + "."
        st.session_state.text_prompt = text
        st.text(st.session_state.text_prompt)
    if st.session_state.text_prompt:
        if st.button("Generate image from text description"):
            # The diffusers pipeline returns an output object; take the first generated image.
            image = image_service(st.session_state.text_prompt).images[0]
            st.image(image)