Spaces: Running on Zero
import os

# `spaces` (the Hugging Face ZeroGPU helper) stays ahead of the torch-based
# imports below: ZeroGPU expects it to be loaded before CUDA is initialised.
import spaces
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch

# Hugging Face Hub repository that provides both the weights and the processor.
MODEL_ID = "daniel3303/QwenStoryteller"

# Load the model and processor once at import time so every request reuses them.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # use the dtype recorded in the checkpoint's config
    device_map="auto",  # let Accelerate decide device placement
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# NOTE(review): with max_new_tokens=4096 a single call may outlast ZeroGPU's
# default time slot; consider @spaces.GPU(duration=...) — confirm on the Space.
@spaces.GPU  # ZeroGPU attaches a GPU only while a decorated function runs
def generate_story(file_paths):
    """Write a story for a sequence of uploaded images using QwenStoryteller.

    At most the first six images are used. Any extra uploads are ignored and
    never opened.

    Args:
        file_paths: Value of the Gradio ``UploadButton``. This is a list of
            uploaded files, each either a path string or an object exposing
            its path as ``.name``. It is ``None`` or empty when nothing has
            been uploaded.

    Returns:
        The decoded text generated for the single prompt in the batch. Given
        the system prompt, it is expected to contain the model's tagged
        reasoning followed by the final story; nothing is stripped here.

    Raises:
        gr.Error: If no images were uploaded, so the UI shows a readable
            message instead of a stack trace.
    """
    if not file_paths:
        raise gr.Error("Please upload at least one image before generating a story.")

    image_content = []
    for file_path in file_paths[:6]:  # Limit to 6 images; the rest are never opened
        # Accept both plain path strings and file objects carrying the path in `.name`.
        path = getattr(file_path, "name", file_path)
        # Image.open is lazy and keeps the file open; copy() decodes the pixels
        # into an independent image so the handle is closed immediately.
        with Image.open(path) as img:
            image_content.append({
                "type": "image",
                "image": img.copy(),
            })
    image_content.append({"type": "text", "text": "Generate a story based on these images."})

    messages = [
        {
            "role": "system",
            "content": "You are an AI storyteller that can analyze sequences of images and create creative narratives. First think step-by-step to analyze characters, objects, settings, and narrative structure. Then create a grounded story that maintains consistent character identity and object references across frames. Use 🧠 tags to show your reasoning process before writing the final story."
        },
        {
            "role": "user",
            "content": image_content,
        }
    ]

    # Render the chat prompt as text; the vision inputs are extracted separately
    # from the same messages and handed to the processor alongside it.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to(model.device)

    # Sampling is enabled, so repeated calls on the same images give different stories.
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    story = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]
    return story
def _uploaded_paths(files):
    """Return the filesystem paths of the files delivered by an upload event.

    Each entry is either a path string or an object exposing its path as
    ``.name``. ``None`` (nothing uploaded) is passed through unchanged so the
    File component is simply cleared.
    """
    if files is None:
        return None
    return [getattr(f, "name", f) for f in files]


with gr.Blocks() as demo:
    gr.Markdown("# Qwen Storyteller \n## Upload up to 6 images to generate a creative story.")
    with gr.Row():
        with gr.Column():
            upload_button = gr.UploadButton("Upload up to 6 images", file_types=["image"], file_count="multiple")
            output_file = gr.File(label="Uploaded Files")
        with gr.Column():
            outputs = gr.Textbox(label="Generated Story", lines=10)
    # Echo the uploads into the File component so the user can see what was received.
    upload_button.upload(_uploaded_paths, upload_button, output_file)
    gen_button = gr.Button("Generate", variant="secondary")
    # The story is generated from the UploadButton's current value, not from `output_file`.
    gen_button.click(generate_story, upload_button, outputs)

if __name__ == "__main__":
    demo.launch()