zaafirriaz commited on
Commit
b300b18
·
verified ·
1 Parent(s): f945833

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -5,13 +5,31 @@ from PIL import Image
5
  import torch
6
 
7
  # Initialize the text generation model
8
- text_generator = pipeline('text-generation', model='gpt2')
 
 
 
 
 
 
9
 
10
  # Initialize the image generation model
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- image_generator = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
 
 
 
 
 
 
 
 
 
13
 
14
  def generate_blog(title):
 
 
 
15
  # Generate blog content
16
  blog_content = text_generator(title, max_length=500, num_return_sequences=1)[0]['generated_text']
17
 
@@ -28,11 +46,11 @@ title = st.text_input('Enter the title of your blog:')
28
  if title:
29
  with st.spinner('Generating blog content and image...'):
30
  blog_content, image = generate_blog(title)
31
- st.success('Blog generated successfully!')
32
- st.subheader('Blog Content')
33
- st.write(blog_content)
34
- st.subheader('Generated Image')
35
- st.image(image, caption='Generated Image')
36
-
37
- if __name__ == '__main__':
38
- st.run()
 
5
  import torch
6
 
7
  # Initialize the text generation model
8
+ def initialize_text_generator():
9
+ try:
10
+ text_generator = pipeline('text-generation', model='gpt2')
11
+ except Exception as e:
12
+ st.error(f"Error loading text generation model: {e}")
13
+ return None
14
+ return text_generator
15
 
16
  # Initialize the image generation model
17
+ def initialize_image_generator():
18
+ try:
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ image_generator = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
21
+ except Exception as e:
22
+ st.error(f"Error loading image generation model: {e}")
23
+ return None
24
+ return image_generator
25
+
26
+ text_generator = initialize_text_generator()
27
+ image_generator = initialize_image_generator()
28
 
29
  def generate_blog(title):
30
+ if text_generator is None or image_generator is None:
31
+ return "Failed to load models", None
32
+
33
  # Generate blog content
34
  blog_content = text_generator(title, max_length=500, num_return_sequences=1)[0]['generated_text']
35
 
 
46
  if title:
47
  with st.spinner('Generating blog content and image...'):
48
  blog_content, image = generate_blog(title)
49
+ if blog_content == "Failed to load models":
50
+ st.error(blog_content)
51
+ else:
52
+ st.success('Blog generated successfully!')
53
+ st.subheader('Blog Content')
54
+ st.write(blog_content)
55
+ st.subheader('Generated Image')
56
+ st.image(image, caption='Generated Image')