Monke64 commited on
Commit
fdd26e9
·
1 Parent(s): 5685b90

Uncommented Load image code

Browse files
Files changed (2) hide show
  1. .idea/.name +1 -0
  2. app.py +7 -7
.idea/.name ADDED
@@ -0,0 +1 @@
 
 
1
+ app.py
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  from flask.Emotion_spotting_service import _Emotion_spotting_service
3
  from flask.Genre_spotting_service import _Genre_spotting_service
4
  from flask.Beat_tracking_service import _Beat_tracking_service
5
- #from diffusers import StableDiffusionPipeline
6
  import torch
7
  import os
8
 
@@ -23,11 +23,11 @@ def load_beat_model():
23
  beat_service = _Beat_tracking_service()
24
  return beat_service
25
 
26
- # @st.cache_resource
27
- # def load_image_model():
28
- # pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
29
- # pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
30
- # return pipeline
31
 
32
 
33
  if 'emotion' not in st.session_state:
@@ -42,7 +42,7 @@ if 'beat' not in st.session_state:
42
  emotion_service = load_emo_model()
43
  genre_service = load_genre_model()
44
  beat_service = load_beat_model()
45
- # image_service = load_image_model()
46
 
47
  st.title("Music2Image webpage")
48
  user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"],key = "file_uploader")
 
2
  from flask.Emotion_spotting_service import _Emotion_spotting_service
3
  from flask.Genre_spotting_service import _Genre_spotting_service
4
  from flask.Beat_tracking_service import _Beat_tracking_service
5
+ from diffusers import StableDiffusionPipeline
6
  import torch
7
  import os
8
 
 
23
  beat_service = _Beat_tracking_service()
24
  return beat_service
25
 
26
+ @st.cache_resource
27
+ def load_image_model():
28
+ pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
29
+ pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
30
+ return pipeline
31
 
32
 
33
  if 'emotion' not in st.session_state:
 
42
  emotion_service = load_emo_model()
43
  genre_service = load_genre_model()
44
  beat_service = load_beat_model()
45
+ image_service = load_image_model()
46
 
47
  st.title("Music2Image webpage")
48
  user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"],key = "file_uploader")