# QwenStoryteller / app.py
import spaces
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch
import os, time
# Load the model and processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "daniel3303/QwenStoryteller",
    torch_dtype=torch.float16,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("daniel3303/QwenStoryteller")
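
# Not part of the original script: if this Space ever runs without a GPU, one
# plausible fallback (an assumption, not the author's setup) would be to load
# the model in float32 on CPU instead, e.g.:
#   model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#       "daniel3303/QwenStoryteller", torch_dtype=torch.float32, device_map="cpu"
#   )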
@spaces.GPU()
@torch.no_grad()
def generate_story(file_paths):
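    # Not in the original script: a small guard (assuming the desired behavior
    # is a friendly message rather than an exception) for when Generate is
    # clicked before any files have been uploaded.
    if not file_paths:
        return "Please upload at least one image first."
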
    # Load images from the file paths
    images = [Image.open(file_path) for file_path in file_paths]

    image_content = []
    for img in images[:6]:  # Limit to 6 images
        image_content.append({
            "type": "image",
            "image": img,
        })
    image_content.append({"type": "text", "text": "Generate a story based on these images."})

    messages = [
        {
            "role": "system",
            "content": "You are an AI storyteller that can analyze sequences of images and create creative narratives. First think step-by-step to analyze characters, objects, settings, and narrative structure. Then create a grounded story that maintains consistent character identity and object references across frames. Use 🧠 tags to show your reasoning process before writing the final story."
        },
        {
            "role": "user",
            "content": image_content,
        }
    ]
    # Apply the chat template and build tensor inputs for the model
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(model.device)
    # Sample a story from the model
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
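
    # Not in the original script: if more reproducible output is preferred,
    # one could decode greedily instead, e.g.
    #   generated_ids = model.generate(**inputs, max_new_tokens=4096, do_sample=False)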
    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    story = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]
    return story

with gr.Blocks() as demo:
    gr.Markdown("# Qwen Storyteller \n## Upload up to 6 images to generate a creative story.")
    with gr.Row():
        with gr.Column():
            upload_button = gr.UploadButton("Upload up to 6 images", file_types=["image"], file_count="multiple")
            output_file = gr.File(label="Uploaded Files")
        with gr.Column():
            outputs = gr.Textbox(label="Generated Story", lines=10)
    # List the uploaded file names, then wire the Generate button to the model
    upload_button.upload(lambda files: [f.name for f in files], upload_button, output_file)
    gen_button = gr.Button("Generate", variant="secondary")
    gen_button.click(generate_story, upload_button, outputs)

if __name__ == "__main__":
    demo.queue().launch(show_error=True)