LPX55 commited on
Commit
7fe0752
Β·
verified Β·
1 Parent(s): fc9b818

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -17
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
4
  from qwen_vl_utils import process_vision_info
5
  from PIL import Image
6
  import torch
 
7
 
8
  # Load the model and processor
9
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -13,23 +14,21 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
13
  )
14
  processor = AutoProcessor.from_pretrained("daniel3303/QwenStoryteller")
15
 
16
- def upload_file(files):
17
- file_paths = [file.name for file in files]
18
- return file_paths
19
-
20
  @spaces.GPU()
21
- def generate_story(images):
 
 
 
 
22
  image_content = []
23
- for img in images[:6]:
24
  image_content.append({
25
  "type": "image",
26
  "image": img,
27
  })
28
 
29
- # Add text prompt at the end
30
  image_content.append({"type": "text", "text": "Generate a story based on these images."})
31
 
32
- # Create messages with system prompt
33
  messages = [
34
  {
35
  "role": "system",
@@ -41,7 +40,6 @@ def generate_story(images):
41
  }
42
  ]
43
 
44
- # Preparation for inference
45
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
46
  image_inputs, video_inputs = process_vision_info(messages)
47
  inputs = processor(
@@ -53,7 +51,6 @@ def generate_story(images):
53
  )
54
  inputs = inputs.to(model.device)
55
 
56
- # Inference: Generate the output
57
  generated_ids = model.generate(
58
  **inputs,
59
  max_new_tokens=4096,
@@ -61,6 +58,7 @@ def generate_story(images):
61
  temperature=0.7,
62
  top_p=0.9
63
  )
 
64
  generated_ids_trimmed = [
65
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
66
  ]
@@ -74,18 +72,19 @@ def generate_story(images):
74
 
75
  with gr.Blocks() as demo:
76
  gr.Markdown("# Qwen Storyteller \n## Upload up to 6 images to generate a creative story.")
 
77
  with gr.Row():
78
  with gr.Column():
79
- file_output = gr.File()
80
- upload_button = gr.UploadButton("Upload up to 6 images", file_types=["image", "video"], file_count="multiple")
81
- gen_button = gr.Button("Generate", variant="secondary")
82
 
83
  with gr.Column():
84
- outputs=gr.Textbox(label="Generated Story", lines=10)
85
 
86
- upload_button.upload(upload_file, upload_button, file_output)
87
- gen_button.click(fn=generate_story, file_output, outputs)
88
-
 
89
 
90
  if __name__ == "__main__":
91
  demo.launch()
 
4
  from qwen_vl_utils import process_vision_info
5
  from PIL import Image
6
  import torch
7
+ import os
8
 
9
  # Load the model and processor
10
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
14
  )
15
  processor = AutoProcessor.from_pretrained("daniel3303/QwenStoryteller")
16
 
 
 
 
 
17
  @spaces.GPU()
18
+ @torch.no_grad()
19
+ def generate_story(file_paths):
20
+ # Load images from the file paths
21
+ images = [Image.open(file_path) for file_path in file_paths]
22
+
23
  image_content = []
24
+ for img in images[:6]: # Limit to 6 images
25
  image_content.append({
26
  "type": "image",
27
  "image": img,
28
  })
29
 
 
30
  image_content.append({"type": "text", "text": "Generate a story based on these images."})
31
 
 
32
  messages = [
33
  {
34
  "role": "system",
 
40
  }
41
  ]
42
 
 
43
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
44
  image_inputs, video_inputs = process_vision_info(messages)
45
  inputs = processor(
 
51
  )
52
  inputs = inputs.to(model.device)
53
 
 
54
  generated_ids = model.generate(
55
  **inputs,
56
  max_new_tokens=4096,
 
58
  temperature=0.7,
59
  top_p=0.9
60
  )
61
+
62
  generated_ids_trimmed = [
63
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
64
  ]
 
72
 
73
  with gr.Blocks() as demo:
74
  gr.Markdown("# Qwen Storyteller \n## Upload up to 6 images to generate a creative story.")
75
+
76
  with gr.Row():
77
  with gr.Column():
78
+ upload_button = gr.UploadButton("Upload up to 6 images", file_types=["image"], file_count="multiple")
79
+ output_file = gr.File(label="Uploaded Files")
 
80
 
81
  with gr.Column():
82
+ outputs = gr.Textbox(label="Generated Story", lines=10)
83
 
84
+ upload_button.upload(lambda files: [f.name for f in files], upload_button, output_file)
85
+
86
+ gen_button = gr.Button("Generate", variant="secondary")
87
+ gen_button.click(generate_story, upload_button, outputs)
88
 
89
  if __name__ == "__main__":
90
  demo.launch()