---
language: en
license: apache-2.0
tags:
- vision-language-model
- visual-storytelling
- chain-of-thought
- grounded-text-generation
- cross-frame-consistency
- storytelling
- image-to-text
- contrastive-learning
- reinforcement-learning
- entity-reidentification
datasets:
- daniel3303/StoryReasoningAdversarialDPO
- daniel3303/StoryReasoning
metrics:
- precision
- recall
- bleu
- meteor
- rouge
- map
base_model:
- daniel3303/QwenStoryteller
pipeline_tag: image-to-text
model-index:
- name: QwenStoryteller2
results:
- task:
type: visual-storytelling
name: Visual Storytelling
dataset:
name: StoryReasoningAdversarialDPO
type: daniel3303/StoryReasoningAdversarialDPO
split: test
metrics:
- name: Character Precision
type: precision
value: 0.78
- name: Object Precision
type: precision
value: 0.29
- name: Total Precision
type: precision
value: 0.45
- name: mAP
type: mean_average_precision
value: 0.31
- name: Character Recall
type: recall
value: 0.77
- name: Object Recall
type: recall
value: 0.28
- name: Total Recall
type: recall
value: 0.48
- name: F1 Score
type: f1
value: 0.41
- name: METEOR
type: meteor
value: 0.17
- name: ROUGE-L
type: rouge-l
value: 0.18
- name: BLEU-4
type: bleu-4
value: 0.057
- name: Character Persistence (≥5 frames)
type: accuracy
value: 0.493
- name: Object Persistence (≥5 frames)
type: accuracy
value: 0.213
- name: Well-structured Stories
type: accuracy
value: 0.975
library_name: transformers
---
# QwenStoryteller2
QwenStoryteller2 is an improved version of QwenStoryteller, fine-tuned using contrastive reinforcement learning with Direct Preference Optimization (DPO) to achieve superior entity re-identification and visual grounding in cross-frame storytelling scenarios.
## Model Description
- **Base Model:** QwenStoryteller (Qwen2.5-VL 7B)
- **Training Method:** Contrastive Reinforcement Learning with Direct Preference Optimization (LoRA rank 2048, alpha 4096)
- **Training Dataset:** [StoryReasoningAdversarialDPO](https://huggingface.co/datasets/daniel3303/StoryReasoningAdversarialDPO)
QwenStoryteller2 builds upon the original QwenStoryteller by addressing critical limitations in cross-frame entity consistency through:
- **Contrastive Learning:** Training on both real and synthetic negative story examples
- **Enhanced Entity Re-identification:** Improved tracking of characters and objects across frames
- **Better Grounding:** Superior alignment between narrative elements and visual entities
- **Reduced Hallucinations:** More reliable entity connections and fewer spurious references
The model employs a dual-component reward function that promotes appropriate entity connections in coherent sequences while discouraging incorrect connections in synthetic arrangements.
## Key Improvements Over QwenStoryteller
- **Grounding Performance:** mAP improved from 0.27 to 0.31 (+14.8%), F1 score from 0.35 to 0.41 (+17.1%)
- **Cross-frame Consistency:** Character persistence on ≥5 frames increased from 37.7% to 49.3% (+30.8%)
- **Pronoun Grounding:** Significant improvements across all pronoun types (he: 90.1%→99.1%, she: 91.1%→98.6%, they: 47.6%→68.8%)
- **Structural Quality:** Well-structured stories increased from 79.1% to 97.5% (+23.3%)
- **Entity Tracking:** Object persistence on ≥5 frames improved from 20.9% to 21.3%
## System Prompt
The model was trained with the following system prompt, and we recommend using it for optimal performance:
```
You are an AI storyteller that can analyze sequences of images and create creative narratives.
First think step-by-step to analyze characters, objects, settings, and narrative structure.
Then create a grounded story that maintains consistent character identity and object references across frames.
Use <think></think> tags to show your reasoning process before writing the final story.
```
## Usage
```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
# Load the model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"daniel3303/QwenStoryteller2", torch_dtype="auto", device_map="auto"
)
# Load processor
processor = AutoProcessor.from_pretrained("daniel3303/QwenStoryteller2")
# Load images
images = [
Image.open("image1.jpg"),
Image.open("image2.jpg"),
Image.open("image3.jpg"),
Image.open("image4.jpg"),
Image.open("image5.jpg")
]
# Create image content list
image_content = []
for img in images:
    image_content.append({
        "type": "image",
        "image": img,
    })
# Add text prompt at the end
image_content.append({"type": "text", "text": "Generate a story based on these images."})
# Create messages with system prompt
messages = [
{
"role": "system",
"content": "You are an AI storyteller that can analyze sequences of images and create creative narratives. First think step-by-step to analyze characters, objects, settings, and narrative structure. Then create a grounded story that maintains consistent character identity and object references across frames. Use <think></think> tags to show your reasoning process before writing the final story."
},
{
"role": "user",
"content": image_content,
}
]
# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
# Inference: Generation of the output
generated_ids = model.generate(
**inputs,
max_new_tokens=4096,
do_sample=True,
temperature=0.7,
top_p=0.9
)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
story = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(story)
```
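The generated text still contains the chain-of-thought inside `<think></think>` tags. A minimal post-processing sketch to separate the reasoning from the final narrative (the helper below is illustrative, not part of the model or processor API):

```python
import re

def split_reasoning_and_story(output: str):
    """Split generated text into the <think> reasoning block and the story."""
    match = re.search(r"<think>(.*?)</think>", output, flags=re.DOTALL)
    reasoning = match.group(1).strip() if match else ""
    narrative = re.sub(r"<think>.*?</think>", "", output, flags=re.DOTALL).strip()
    return reasoning, narrative

reasoning, final_story = split_reasoning_and_story(story)  # `story` from the snippet above
print(final_story)
```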
### Using vLLM for faster inference
For significantly faster inference, you can use vLLM to serve the model:
```bash
# Install vLLM
pip install vllm
# Serve the model with vLLM
vllm serve daniel3303/QwenStoryteller2
```
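Once the server is up, it exposes an OpenAI-compatible endpoint (by default at `http://localhost:8000/v1`). A minimal client sketch, assuming your images are reachable as URLs (the `example.com` URLs below are placeholders) and reusing the recommended system prompt:

```python
from openai import OpenAI

# vLLM's OpenAI-compatible server; the API key is unused but required by the client.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

system_prompt = (
    "You are an AI storyteller that can analyze sequences of images and create creative narratives. "
    "First think step-by-step to analyze characters, objects, settings, and narrative structure. "
    "Then create a grounded story that maintains consistent character identity and object references "
    "across frames. Use <think></think> tags to show your reasoning process before writing the final story."
)

response = client.chat.completions.create(
    model="daniel3303/QwenStoryteller2",
    messages=[
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}},
                {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}},
                {"type": "text", "text": "Generate a story based on these images."},
            ],
        },
    ],
    max_tokens=4096,
    temperature=0.7,
    top_p=0.9,
)
print(response.choices[0].message.content)
```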
## Training Methodology
### Contrastive Learning Framework
QwenStoryteller2 was trained using a novel contrastive reinforcement learning approach:
1. **Synthetic Story Generation:** Extended the StoryReasoning dataset with 4,178 synthetic stories created by sampling images from different movies to create incoherent sequences
2. **Dual-Component Reward Function:** Combined entity re-identification (R_reid) and grounding (R_ground) rewards with structural validation
3. **Direct Preference Optimization:** Used offline preference pairs generated from the reward function to train the model (a pair-construction sketch follows this list)
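As an illustration only (the data layout and helper names below are hypothetical, not the released training code), offline preference pairs can be assembled by scoring candidate stories with the reward function and keeping the highest- and lowest-scoring completion for each image sequence:

```python
def build_preference_pairs(examples, reward_fn):
    """Turn scored candidate stories into (chosen, rejected) pairs for DPO.

    `examples` is assumed to map each image sequence (real or synthetic) to a
    prompt and a list of candidate stories; `reward_fn` stands in for the
    dual-component reward described in the next subsection.
    """
    pairs = []
    for example in examples:
        scored = sorted(
            ((reward_fn(story, example["is_synthetic"]), story) for story in example["candidates"]),
            key=lambda pair: pair[0],
            reverse=True,
        )
        best_score, chosen = scored[0]
        worst_score, rejected = scored[-1]
        if best_score > worst_score:  # skip degenerate cases with no preference signal
            pairs.append({"prompt": example["prompt"], "chosen": chosen, "rejected": rejected})
    return pairs
```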
### Reward Function Components
- **Entity Re-identification Reward:** Tracks character and object persistence across frames, promoting connections in real stories while penalizing them in synthetic ones
- **Grounding Reward:** Evaluates pronoun and proper noun grounding to visual entities
- **Structure Validation:** Ensures generated outputs maintain the required format and consistency (a combined-reward sketch follows this list)
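The exact reward implementation is not published on this card; the function below is a minimal sketch of how the components above could be combined, with the scores passed in as plain numbers rather than computed from model output:

```python
def contrastive_reward(r_reid: float, r_ground: float, *, is_synthetic: bool,
                       well_structured: bool, reid_weight: float = 0.5,
                       ground_weight: float = 0.5) -> float:
    """Combine the dual-component reward (weights here are illustrative).

    r_reid:   entity re-identification score (cross-frame persistence), in [0, 1]
    r_ground: pronoun / proper-noun grounding score, in [0, 1]
    """
    if not well_structured:
        # Structure validation gate: malformed outputs receive a flat penalty.
        return -1.0
    # Re-identification is rewarded in coherent (real) sequences but penalized in
    # synthetic, incoherent ones, where cross-frame connections are spurious.
    reid_term = -r_reid if is_synthetic else r_reid
    return reid_weight * reid_term + ground_weight * r_ground

# A coherent story with strong re-identification outscores a synthetic sequence
# exhibiting the same (spurious) cross-frame connections.
print(contrastive_reward(0.9, 0.8, is_synthetic=False, well_structured=True))  #  0.85
print(contrastive_reward(0.9, 0.8, is_synthetic=True, well_structured=True))   # -0.05
```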
### Training Configuration
- **Method:** Direct Preference Optimization (DPO) with LoRA fine-tuning (illustrated after this list)
- **LoRA Parameters:** Rank 2048, alpha 4096
- **Optimizer:** AdamW with learning rate 5×10⁻⁶
- **Batch Size:** 8
- **Epochs:** 3
- **Temperature Parameter (β):** 0.1
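As a rough illustration, this configuration maps naturally onto Hugging Face TRL and PEFT. The sketch below is not the released training code: `model`, `processor`, and `preference_dataset` are assumed to be loaded beforehand, the batch size is interpreted as per-device, and vision-language DPO may require additional data handling not shown here.

```python
from peft import LoraConfig
from trl import DPOConfig, DPOTrainer

lora_config = LoraConfig(r=2048, lora_alpha=4096, task_type="CAUSAL_LM")

training_args = DPOConfig(
    output_dir="qwen-storyteller2-dpo",
    beta=0.1,                         # DPO temperature parameter
    learning_rate=5e-6,
    per_device_train_batch_size=8,
    num_train_epochs=3,
)

trainer = DPOTrainer(
    model=model,                      # base QwenStoryteller checkpoint, loaded beforehand
    args=training_args,
    train_dataset=preference_dataset, # prompt/chosen/rejected pairs (StoryReasoningAdversarialDPO)
    processing_class=processor,
    peft_config=lora_config,
)
trainer.train()
```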
## Performance Metrics
| Metric | QwenStoryteller | QwenStoryteller2 | Improvement |
|--------|-----------------|------------------|-------------|
| Character Precision | 0.83 | 0.78 | -6.0% |
| Object Precision | 0.46 | 0.29 | -37.0% |
| Total Precision | 0.57 | 0.45 | -21.1% |
| mAP | 0.27 | 0.31 | +14.8% |
| Character Recall | 0.62 | 0.77 | +24.2% |
| Object Recall | 0.25 | 0.28 | +12.0% |
| Total Recall | 0.40 | 0.48 | +20.0% |
| F1 Score | 0.35 | 0.41 | +17.1% |
| METEOR | 0.14 | 0.17 | +21.4% |
| ROUGE-L | 0.16 | 0.18 | +12.5% |
| BLEU-4 | 0.054 | 0.057 | +5.6% |
## Output Format
QwenStoryteller2 produces structured outputs with improved consistency:
1. **Chain-of-Thought Analysis (`<think></think>`):** More accurate structured analysis with:
- Improved character tables with consistent identity references
- Better object tracking with enhanced spatial coordination
- More reliable setting categorization
- Stronger narrative structure modeling
2. **Grounded Story:** Enhanced narrative with specialized XML tags (a parsing sketch follows this list):
- `<gdi>`: Image tags for specific frames
- `<gdo>`: Entity reference tags with improved accuracy
- `<gda>`: Action tags with better character-action alignment
- `<gdl>`: Location/landmark tags with enhanced spatial grounding
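These tags can be consumed directly or stripped to obtain a plain-text story. A small illustrative helper (not a released parser; the attribute layout inside the tags is an assumption):

```python
import re

GROUNDING_TAGS = ("gdi", "gdo", "gda", "gdl")

def extract_grounded_spans(story: str):
    """Collect (tag, attributes, text) triples from a grounded story."""
    spans = []
    for tag in GROUNDING_TAGS:
        # Tags are assumed to look like <gdo char1>the detective</gdo>; attributes may be empty.
        for match in re.finditer(rf"<{tag}([^>]*)>(.*?)</{tag}>", story, flags=re.DOTALL):
            spans.append((tag, match.group(1).strip(), match.group(2)))
    return spans

def strip_grounding_tags(story: str) -> str:
    """Return the narrative text with all grounding tags removed."""
    return re.sub(r"</?gd[ioal][^>]*>", "", story)
```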
## Key Features
- **Enhanced Cross-Frame Consistency:** Superior character and object identity maintenance through contrastive learning
- **Improved Pronoun Grounding:** Better alignment of pronouns with visual entities (up to 99.1% for "he", 98.6% for "she")
- **Reduced Hallucinations:** Fewer incorrect entity connections and spurious references
- **Robust Entity Discrimination:** Learned ability to distinguish when cross-frame connections are appropriate
- **Better Structural Quality:** Near-perfect adherence to expected output format (97.5%)
## Limitations
- Precision is lower than the original model's, a trade-off for the substantially higher recall
- Training data derived from movies may introduce cinematic biases
- Entity re-identification still relies primarily on visual similarity within bounding boxes
- Performance validated only on 7B parameter scale
- Optimal real-to-synthetic story ratio (2:1) may not generalize to all scenarios
## Citation
```bibtex
TODO
@misc{oliveira2025storyreasoningdatasetusingchainofthought,
title={StoryReasoning Dataset: Using Chain-of-Thought for Scene Understanding and Grounded Story Generation},
author={Daniel A. P. Oliveira and David Martins de Matos},
year={2025},
eprint={2505.10292},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2505.10292}
}
```
## Contact
For questions or feedback regarding this model, please contact:
- Daniel A. P. Oliveira ([email protected])