Upload 8 files
Browse files- README.md +115 -13
- app.py +80 -0
- evaluate_model.py +151 -0
- huggingface-metadata.json +12 -0
- requirements.txt +8 -0
- run_server.py +51 -0
- story_generator.py +142 -0
- test_server.py +101 -0
README.md
CHANGED
@@ -1,13 +1,115 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Rawi Kids Vision-Language Model
|
2 |
+
|
3 |
+
A vision-language model that generates engaging short stories for children (ages 6-12) based on images. This project is designed to be integrated with the Rawi Kids Flutter application and uses the DeepSeek Vision API.
|
4 |
+
|
5 |
+
## Features
|
6 |
+
|
7 |
+
- Generate age-appropriate stories from images
|
8 |
+
- Support for different age groups (6-8 and 9-12 years)
|
9 |
+
- Optional themes to influence story generation (adventure, fantasy, animals, etc.)
|
10 |
+
- Gradio web interface for easy testing
|
11 |
+
- Integration with Flutter app
|
12 |
+
- Uses DeepSeek Vision API (no local model needed)
|
13 |
+
|
14 |
+
## Demo
|
15 |
+
|
16 |
+
This model can be tested using the Gradio web interface included in the project.
|
17 |
+
|
18 |
+
## Setup and Installation
|
19 |
+
|
20 |
+
### Prerequisites
|
21 |
+
|
22 |
+
- Python 3.8 or higher
|
23 |
+
- pip (Python package manager)
|
24 |
+
- Virtual environment (recommended)
|
25 |
+
- DeepSeek API Key
|
26 |
+
|
27 |
+
### Getting a DeepSeek API Key
|
28 |
+
|
29 |
+
1. Visit the [DeepSeek website](https://www.deepseek.com/) and sign up for an account
|
30 |
+
2. Navigate to your API settings page to obtain an API key
|
31 |
+
3. Copy the API key for use in the next steps
|
32 |
+
|
33 |
+
### Installation
|
34 |
+
|
35 |
+
1. Clone this repository
|
36 |
+
```
|
37 |
+
git clone <repository-url>
|
38 |
+
cd rawi-kids-vlm
|
39 |
+
```
|
40 |
+
|
41 |
+
2. Create and activate a virtual environment
|
42 |
+
```
|
43 |
+
python -m venv venv
|
44 |
+
# On Windows
|
45 |
+
venv\Scripts\activate
|
46 |
+
# On macOS/Linux
|
47 |
+
source venv/bin/activate
|
48 |
+
```
|
49 |
+
|
50 |
+
3. Install the required packages
|
51 |
+
```
|
52 |
+
pip install -r requirements.txt
|
53 |
+
```
|
54 |
+
|
55 |
+
4. Create a `.env` file and add your DeepSeek API key
|
56 |
+
```
|
57 |
+
echo "DEEPSEEK_API_KEY=your_api_key_here" > .env
|
58 |
+
```
|
59 |
+
|
60 |
+
5. Run the Gradio app
|
61 |
+
```
|
62 |
+
python app.py
|
63 |
+
```
|
64 |
+
|
65 |
+
The interface will be available at http://localhost:7860
|
66 |
+
|
67 |
+
## Using the Interface
|
68 |
+
|
69 |
+
1. Upload an image using the file uploader
|
70 |
+
2. Select the target age group (6-8 or 9-12 years)
|
71 |
+
3. Choose a story theme (optional)
|
72 |
+
4. Click "Generate Story"
|
73 |
+
5. The model will analyze the image and generate an age-appropriate story
|
74 |
+
|
75 |
+
## Flutter Integration
|
76 |
+
|
77 |
+
See the `test_server.py` file for examples of how to integrate with your Flutter app. You'll need to implement an API client in your Flutter app that sends images to this service and receives the generated stories.
|
78 |
+
|
79 |
+
## Testing
|
80 |
+
|
81 |
+
You can test the model using the provided test script:
|
82 |
+
|
83 |
+
```
|
84 |
+
python test_server.py --url http://localhost:7860 --image path/to/test_image.jpg
|
85 |
+
```
|
86 |
+
|
87 |
+
## Evaluation
|
88 |
+
|
89 |
+
For more detailed evaluation of the model's performance, use the evaluation script:
|
90 |
+
|
91 |
+
```
|
92 |
+
python evaluate_model.py --images test_images --output evaluation_results.json
|
93 |
+
```
|
94 |
+
|
95 |
+
## Deploying to Hugging Face Spaces
|
96 |
+
|
97 |
+
This project is designed to work with Hugging Face Spaces, which provides free hosting for machine learning demos.
|
98 |
+
|
99 |
+
1. Create a new Space on Hugging Face
|
100 |
+
2. Select "Gradio" as the SDK
|
101 |
+
3. Push this repository to the Space
|
102 |
+
4. Add your DeepSeek API key as a secret in the Space configuration
|
103 |
+
5. The app will automatically deploy and be available at your Space URL
|
104 |
+
|
105 |
+
## Important Note on API Usage
|
106 |
+
|
107 |
+
The DeepSeek API is a commercial service and may have usage limits or costs associated with it. Make sure to check their pricing and terms of service to understand any potential costs for your usage level.
|
108 |
+
|
109 |
+
## License
|
110 |
+
|
111 |
+
[Add your license information here]
|
112 |
+
|
113 |
+
## Contact
|
114 |
+
|
115 |
+
[Add your contact information here]
|
app.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from story_generator import StoryGenerator
|
3 |
+
import tempfile
|
4 |
+
import os
|
5 |
+
|
6 |
+
# Initialize the story generator
|
7 |
+
story_generator = StoryGenerator()
|
8 |
+
|
9 |
+
# Define the available themes
|
10 |
+
THEMES = ["None", "Adventure", "Fantasy", "Animals", "Friendship", "Science"]
|
11 |
+
|
12 |
+
def generate_story(image, age_group, theme):
|
13 |
+
"""
|
14 |
+
Generate a story from an image using the story generator
|
15 |
+
|
16 |
+
Args:
|
17 |
+
image: The uploaded image
|
18 |
+
age_group: The target age group
|
19 |
+
theme: The story theme
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
str: The generated story
|
23 |
+
"""
|
24 |
+
# Save the image to a temporary file
|
25 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp:
|
26 |
+
image.save(temp.name)
|
27 |
+
temp_filename = temp.name
|
28 |
+
|
29 |
+
try:
|
30 |
+
# Process the theme (convert "None" to None)
|
31 |
+
processed_theme = None if theme == "None" else theme.lower()
|
32 |
+
|
33 |
+
# Open the image file and generate the story
|
34 |
+
with open(temp_filename, 'rb') as img_file:
|
35 |
+
story = story_generator.generate(img_file, age_group, processed_theme)
|
36 |
+
|
37 |
+
return story
|
38 |
+
except Exception as e:
|
39 |
+
return f"Error generating story: {str(e)}"
|
40 |
+
finally:
|
41 |
+
# Clean up the temporary file
|
42 |
+
if os.path.exists(temp_filename):
|
43 |
+
os.unlink(temp_filename)
|
44 |
+
|
45 |
+
# Create the Gradio interface
|
46 |
+
with gr.Blocks(title="Rawi Kids Story Generator") as demo:
|
47 |
+
gr.Markdown("# Rawi Kids Story Generator")
|
48 |
+
gr.Markdown("Upload an image and get a story for kids!")
|
49 |
+
|
50 |
+
with gr.Row():
|
51 |
+
with gr.Column(scale=1):
|
52 |
+
# Input components
|
53 |
+
image_input = gr.Image(type="pil", label="Upload Image")
|
54 |
+
age_group = gr.Radio(choices=["6-8", "9-12"], value="6-8", label="Age Group (years)")
|
55 |
+
theme = gr.Dropdown(choices=THEMES, value="None", label="Story Theme")
|
56 |
+
submit_btn = gr.Button("Generate Story", variant="primary")
|
57 |
+
|
58 |
+
with gr.Column(scale=1):
|
59 |
+
# Output component
|
60 |
+
story_output = gr.Textbox(label="Generated Story", lines=10)
|
61 |
+
|
62 |
+
# Set up the button click event
|
63 |
+
submit_btn.click(
|
64 |
+
fn=generate_story,
|
65 |
+
inputs=[image_input, age_group, theme],
|
66 |
+
outputs=story_output
|
67 |
+
)
|
68 |
+
|
69 |
+
gr.Markdown("""
|
70 |
+
### How it works
|
71 |
+
|
72 |
+
1. Upload a picture or take a photo
|
73 |
+
2. Select the age group (6-8 or 9-12 years)
|
74 |
+
3. Choose a theme for the story (optional)
|
75 |
+
4. Click "Generate Story"
|
76 |
+
5. The AI will analyze the image and create a story for kids!
|
77 |
+
""")
|
78 |
+
|
79 |
+
# For Hugging Face Spaces
|
80 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
evaluate_model.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
from PIL import Image
|
5 |
+
import glob
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
from story_generator import StoryGenerator
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
import sys
|
11 |
+
|
12 |
+
# Load environment variables for API key
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
class ModelEvaluator:
|
16 |
+
def __init__(self, images_dir, output_file):
|
17 |
+
"""Initialize the evaluator with paths to images and output file"""
|
18 |
+
self.images_dir = images_dir
|
19 |
+
self.output_file = output_file
|
20 |
+
|
21 |
+
# Check for API key
|
22 |
+
if not os.getenv("DEEPSEEK_API_KEY"):
|
23 |
+
print("ERROR: DEEPSEEK_API_KEY environment variable not found.")
|
24 |
+
print("Please set your DeepSeek API key using:")
|
25 |
+
print(" - Create a .env file with DEEPSEEK_API_KEY=your_key_here")
|
26 |
+
print(" - Or set the environment variable directly")
|
27 |
+
sys.exit(1)
|
28 |
+
|
29 |
+
# Initialize story generator
|
30 |
+
self.generator = StoryGenerator()
|
31 |
+
|
32 |
+
# Create output directory if it doesn't exist
|
33 |
+
os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
|
34 |
+
|
35 |
+
def evaluate_all(self, limit=None):
|
36 |
+
"""Evaluate the model on all images in the directory"""
|
37 |
+
image_files = glob.glob(os.path.join(self.images_dir, "*.jpg")) + \
|
38 |
+
glob.glob(os.path.join(self.images_dir, "*.jpeg")) + \
|
39 |
+
glob.glob(os.path.join(self.images_dir, "*.png"))
|
40 |
+
|
41 |
+
# Limit the number of images if specified (to control API usage)
|
42 |
+
if limit and limit > 0:
|
43 |
+
image_files = image_files[:limit]
|
44 |
+
|
45 |
+
print(f"Found {len(image_files)} images for evaluation")
|
46 |
+
print(f"NOTE: Using DeepSeek API - API call charges may apply")
|
47 |
+
|
48 |
+
results = []
|
49 |
+
|
50 |
+
for idx, img_path in enumerate(image_files):
|
51 |
+
print(f"\nProcessing image {idx + 1}/{len(image_files)}: {os.path.basename(img_path)}")
|
52 |
+
|
53 |
+
# Test with different age groups and themes (use fewer combinations to limit API calls)
|
54 |
+
for age_group in ["6-8", "9-12"]:
|
55 |
+
# Limit theme testing to save on API calls
|
56 |
+
for theme in [None, "adventure"]:
|
57 |
+
theme_str = theme if theme else "none"
|
58 |
+
print(f" Generating story for age group: {age_group}, theme: {theme_str}")
|
59 |
+
|
60 |
+
try:
|
61 |
+
start_time = time.time()
|
62 |
+
with open(img_path, 'rb') as img_file:
|
63 |
+
story = self.generator.generate(img_file, age_group, theme)
|
64 |
+
generation_time = time.time() - start_time
|
65 |
+
|
66 |
+
# Record the result
|
67 |
+
result = {
|
68 |
+
"image_path": img_path,
|
69 |
+
"age_group": age_group,
|
70 |
+
"theme": theme_str,
|
71 |
+
"generation_time_seconds": round(generation_time, 2),
|
72 |
+
"story_length_chars": len(story),
|
73 |
+
"story_words": len(story.split()),
|
74 |
+
"story": story
|
75 |
+
}
|
76 |
+
|
77 |
+
results.append(result)
|
78 |
+
|
79 |
+
# Print summary
|
80 |
+
print(f" Time: {result['generation_time_seconds']:.2f}s, "
|
81 |
+
f"Words: {result['story_words']}")
|
82 |
+
|
83 |
+
except Exception as e:
|
84 |
+
print(f" Error generating story: {str(e)}")
|
85 |
+
results.append({
|
86 |
+
"image_path": img_path,
|
87 |
+
"age_group": age_group,
|
88 |
+
"theme": theme_str,
|
89 |
+
"error": str(e)
|
90 |
+
})
|
91 |
+
|
92 |
+
# Save all results to file
|
93 |
+
with open(self.output_file, 'w') as f:
|
94 |
+
json.dump(results, f, indent=2)
|
95 |
+
|
96 |
+
print(f"\nEvaluation complete. Results saved to {self.output_file}")
|
97 |
+
return results
|
98 |
+
|
99 |
+
def print_summary(self, results):
|
100 |
+
"""Print a summary of the evaluation results"""
|
101 |
+
if not results:
|
102 |
+
print("No results to summarize")
|
103 |
+
return
|
104 |
+
|
105 |
+
successful_generations = [r for r in results if "error" not in r]
|
106 |
+
error_generations = [r for r in results if "error" in r]
|
107 |
+
|
108 |
+
print(f"\nEvaluation Summary:")
|
109 |
+
print(f" Total images processed: {len(set([r['image_path'] for r in results]))}")
|
110 |
+
print(f" Total story generations: {len(results)}")
|
111 |
+
print(f" Successful generations: {len(successful_generations)}")
|
112 |
+
print(f" Failed generations: {len(error_generations)}")
|
113 |
+
|
114 |
+
if successful_generations:
|
115 |
+
avg_time = sum([r["generation_time_seconds"] for r in successful_generations]) / len(successful_generations)
|
116 |
+
avg_words = sum([r["story_words"] for r in successful_generations]) / len(successful_generations)
|
117 |
+
print(f" Average generation time: {avg_time:.2f} seconds")
|
118 |
+
print(f" Average story length: {avg_words:.1f} words")
|
119 |
+
|
120 |
+
# Analysis by age group
|
121 |
+
age_groups = ["6-8", "9-12"]
|
122 |
+
for age_group in age_groups:
|
123 |
+
age_results = [r for r in successful_generations if r["age_group"] == age_group]
|
124 |
+
if age_results:
|
125 |
+
avg_words = sum([r["story_words"] for r in age_results]) / len(age_results)
|
126 |
+
print(f" Age group {age_group}: {len(age_results)} stories, avg {avg_words:.1f} words")
|
127 |
+
|
128 |
+
# Analysis by theme
|
129 |
+
themes = ["none", "adventure", "fantasy", "animals"]
|
130 |
+
for theme in themes:
|
131 |
+
theme_results = [r for r in successful_generations if r["theme"] == theme]
|
132 |
+
if theme_results:
|
133 |
+
avg_words = sum([r["story_words"] for r in theme_results]) / len(theme_results)
|
134 |
+
print(f" Theme {theme}: {len(theme_results)} stories, avg {avg_words:.1f} words")
|
135 |
+
|
136 |
+
|
137 |
+
def main():
|
138 |
+
parser = argparse.ArgumentParser(description='Evaluate the story generation model')
|
139 |
+
parser.add_argument('--images', default='test_images', help='Directory containing test images')
|
140 |
+
parser.add_argument('--output', default='evaluation_results.json', help='Output file for results')
|
141 |
+
parser.add_argument('--limit', type=int, default=2, help='Limit the number of images to process (to control API usage)')
|
142 |
+
|
143 |
+
args = parser.parse_args()
|
144 |
+
|
145 |
+
evaluator = ModelEvaluator(args.images, args.output)
|
146 |
+
results = evaluator.evaluate_all(limit=args.limit)
|
147 |
+
evaluator.print_summary(results)
|
148 |
+
|
149 |
+
|
150 |
+
if __name__ == '__main__':
|
151 |
+
main()
|
huggingface-metadata.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"title": "Rawi Kids Story Generator",
|
3 |
+
"emoji": "📚",
|
4 |
+
"colorFrom": "blue",
|
5 |
+
"colorTo": "purple",
|
6 |
+
"sdk": "gradio",
|
7 |
+
"sdk_version": "3.50.2",
|
8 |
+
"python_version": "3.10",
|
9 |
+
"app_file": "app.py",
|
10 |
+
"pinned": false,
|
11 |
+
"license": "mit"
|
12 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
flask==2.0.1
|
2 |
+
pillow==9.5.0
|
3 |
+
python-dotenv==1.0.0
|
4 |
+
flask-cors==3.0.10
|
5 |
+
gunicorn==20.1.0
|
6 |
+
numpy==1.24.3
|
7 |
+
requests==2.31.0
|
8 |
+
gradio==3.50.2
|
run_server.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import argparse
|
5 |
+
import subprocess
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
def main():
|
9 |
+
parser = argparse.ArgumentParser(description='Run the Vision Language Model server')
|
10 |
+
parser.add_argument('--port', type=int, help='Port to run the server on')
|
11 |
+
parser.add_argument('--debug', action='store_true', help='Run in debug mode')
|
12 |
+
parser.add_argument('--host', default='0.0.0.0', help='Host to run the server on')
|
13 |
+
parser.add_argument('--workers', type=int, default=1, help='Number of Gunicorn workers')
|
14 |
+
parser.add_argument('--use-gunicorn', action='store_true', help='Use Gunicorn for production')
|
15 |
+
|
16 |
+
args = parser.parse_args()
|
17 |
+
|
18 |
+
# Load environment variables
|
19 |
+
load_dotenv()
|
20 |
+
|
21 |
+
# Set environment variables from command line arguments
|
22 |
+
if args.port:
|
23 |
+
os.environ['PORT'] = str(args.port)
|
24 |
+
|
25 |
+
if args.debug:
|
26 |
+
os.environ['DEBUG'] = 'True'
|
27 |
+
|
28 |
+
port = int(os.environ.get('PORT', 5000))
|
29 |
+
debug = os.environ.get('DEBUG', 'False').lower() == 'true'
|
30 |
+
|
31 |
+
print(f"Starting server on {args.host}:{port}")
|
32 |
+
print(f"Debug mode: {debug}")
|
33 |
+
|
34 |
+
if args.use_gunicorn:
|
35 |
+
# Use Gunicorn for production
|
36 |
+
cmd = [
|
37 |
+
'gunicorn',
|
38 |
+
'--bind', f"{args.host}:{port}",
|
39 |
+
'--workers', str(args.workers),
|
40 |
+
'app:app'
|
41 |
+
]
|
42 |
+
print(f"Running with gunicorn: {' '.join(cmd)}")
|
43 |
+
subprocess.call(cmd)
|
44 |
+
else:
|
45 |
+
# Use Flask's built-in server
|
46 |
+
from app import app
|
47 |
+
app.run(host=args.host, port=port, debug=debug)
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == '__main__':
|
51 |
+
main()
|
story_generator.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
import base64
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
import json
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
|
9 |
+
# Load environment variables for API keys
|
10 |
+
load_dotenv()
|
11 |
+
|
12 |
+
class StoryGenerator:
|
13 |
+
def __init__(self):
|
14 |
+
"""Initialize the story generator with DeepSeek API configuration"""
|
15 |
+
self.api_key = os.getenv("DEEPSEEK_API_KEY")
|
16 |
+
if not self.api_key:
|
17 |
+
print("Warning: DEEPSEEK_API_KEY not found in environment variables. Please set it.")
|
18 |
+
|
19 |
+
self.api_url = "https://api.deepseek.com/v1/chat/completions"
|
20 |
+
|
21 |
+
# Story templates for different age groups
|
22 |
+
self.templates = {
|
23 |
+
"6-8": "Write a simple and fun short story for a 6-8 year old child about this image: ",
|
24 |
+
"9-12": "Write an engaging short story with a simple moral for a 9-12 year old about this image: "
|
25 |
+
}
|
26 |
+
|
27 |
+
# Themes and associated vocabulary to enhance stories
|
28 |
+
self.themes = {
|
29 |
+
"adventure": ["journey", "discover", "explore", "treasure", "map"],
|
30 |
+
"fantasy": ["magic", "dragon", "wizard", "fairy", "kingdom"],
|
31 |
+
"animals": ["forest", "pets", "wildlife", "jungle", "farm"],
|
32 |
+
"friendship": ["friends", "sharing", "helping", "together", "team"],
|
33 |
+
"science": ["experiment", "invention", "discovery", "robot", "space"]
|
34 |
+
}
|
35 |
+
|
36 |
+
def generate(self, image_file, age_group="6-12", theme=None):
|
37 |
+
"""
|
38 |
+
Generate a story based on the input image using DeepSeek API
|
39 |
+
|
40 |
+
Args:
|
41 |
+
image_file: The uploaded image file
|
42 |
+
age_group: Age group target ("6-8" or "9-12")
|
43 |
+
theme: Optional theme to influence the story
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
str: A generated story suitable for the specified age group
|
47 |
+
"""
|
48 |
+
try:
|
49 |
+
# Process the image
|
50 |
+
image = Image.open(image_file).convert('RGB')
|
51 |
+
|
52 |
+
# Resize image if too large
|
53 |
+
max_size = 1024
|
54 |
+
if max(image.size) > max_size:
|
55 |
+
ratio = max_size / max(image.size)
|
56 |
+
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
57 |
+
image = image.resize(new_size, Image.LANCZOS)
|
58 |
+
|
59 |
+
# Convert image to base64
|
60 |
+
buffered = BytesIO()
|
61 |
+
image.save(buffered, format="JPEG", quality=85)
|
62 |
+
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
63 |
+
|
64 |
+
# Determine the template based on age group
|
65 |
+
template = self.templates.get(age_group, self.templates["6-8"])
|
66 |
+
|
67 |
+
# Enhance the prompt with theme if provided
|
68 |
+
if theme and theme in self.themes:
|
69 |
+
theme_words = ", ".join(self.themes[theme][:3]) # Use first 3 theme words
|
70 |
+
prompt = f"{template} Please include elements of {theme} like {theme_words}."
|
71 |
+
else:
|
72 |
+
prompt = template
|
73 |
+
|
74 |
+
# Create the API payload
|
75 |
+
payload = {
|
76 |
+
"model": "deepseek-vision",
|
77 |
+
"messages": [
|
78 |
+
{
|
79 |
+
"role": "user",
|
80 |
+
"content": [
|
81 |
+
{
|
82 |
+
"type": "text",
|
83 |
+
"text": prompt
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"type": "image_url",
|
87 |
+
"image_url": {
|
88 |
+
"url": f"data:image/jpeg;base64,{img_base64}"
|
89 |
+
}
|
90 |
+
}
|
91 |
+
]
|
92 |
+
}
|
93 |
+
],
|
94 |
+
"max_tokens": 1000,
|
95 |
+
"temperature": 0.7
|
96 |
+
}
|
97 |
+
|
98 |
+
# Set the headers with authorization
|
99 |
+
headers = {
|
100 |
+
"Content-Type": "application/json",
|
101 |
+
"Authorization": f"Bearer {self.api_key}"
|
102 |
+
}
|
103 |
+
|
104 |
+
# Make the API request
|
105 |
+
response = requests.post(self.api_url, headers=headers, json=payload)
|
106 |
+
response.raise_for_status()
|
107 |
+
|
108 |
+
# Parse the API response
|
109 |
+
result = response.json()
|
110 |
+
story = result.get("choices", [{}])[0].get("message", {}).get("content", "")
|
111 |
+
|
112 |
+
if not story:
|
113 |
+
raise ValueError("No story was generated from the API")
|
114 |
+
|
115 |
+
# Format the story
|
116 |
+
story = self._format_story(story, age_group)
|
117 |
+
|
118 |
+
return story
|
119 |
+
|
120 |
+
except Exception as e:
|
121 |
+
print(f"Error generating story: {str(e)}")
|
122 |
+
raise e
|
123 |
+
|
124 |
+
def _format_story(self, story, age_group):
|
125 |
+
"""Format the story based on age group"""
|
126 |
+
# Add paragraph breaks every 2-3 sentences
|
127 |
+
sentences = story.split('.')
|
128 |
+
formatted_text = ""
|
129 |
+
|
130 |
+
for i, sentence in enumerate(sentences):
|
131 |
+
if sentence.strip(): # Skip empty sentences
|
132 |
+
formatted_text += sentence.strip() + "."
|
133 |
+
if i % 3 == 2: # Add paragraph break every 3 sentences
|
134 |
+
formatted_text += "\n\n"
|
135 |
+
|
136 |
+
# For younger kids, keep it shorter by taking just the first few paragraphs
|
137 |
+
if age_group == "6-8":
|
138 |
+
paragraphs = formatted_text.split("\n\n")
|
139 |
+
if len(paragraphs) > 3:
|
140 |
+
formatted_text = "\n\n".join(paragraphs[:3])
|
141 |
+
|
142 |
+
return formatted_text
|
test_server.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
from PIL import Image
|
5 |
+
import json
|
6 |
+
import base64
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
def test_health_endpoint(base_url):
|
10 |
+
"""Test if the Gradio server is running"""
|
11 |
+
try:
|
12 |
+
response = requests.get(f"{base_url}")
|
13 |
+
print(f"Server health check: {response.status_code}")
|
14 |
+
assert response.status_code == 200
|
15 |
+
return response.status_code == 200
|
16 |
+
except Exception as e:
|
17 |
+
print(f"Error connecting to server: {str(e)}")
|
18 |
+
return False
|
19 |
+
|
20 |
+
def test_story_generation(base_url, image_path, age_group="6-8", theme="adventure"):
|
21 |
+
"""Test story generation with an image using the Gradio API"""
|
22 |
+
if not os.path.exists(image_path):
|
23 |
+
print(f"Error: Image file not found at {image_path}")
|
24 |
+
return False
|
25 |
+
|
26 |
+
# Ensure the image can be opened
|
27 |
+
try:
|
28 |
+
img = Image.open(image_path)
|
29 |
+
img_format = img.format if img.format else "JPEG"
|
30 |
+
img_buffer = BytesIO()
|
31 |
+
img.save(img_buffer, format=img_format)
|
32 |
+
img_bytes = img_buffer.getvalue()
|
33 |
+
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
|
34 |
+
except Exception as e:
|
35 |
+
print(f"Error processing image: {str(e)}")
|
36 |
+
return False
|
37 |
+
|
38 |
+
# Prepare the API request for Gradio
|
39 |
+
url = f"{base_url}/api/predict"
|
40 |
+
|
41 |
+
# Convert theme to the correct format for Gradio
|
42 |
+
if theme.lower() == "none":
|
43 |
+
theme = "None"
|
44 |
+
else:
|
45 |
+
theme = theme.capitalize()
|
46 |
+
|
47 |
+
# Build the payload
|
48 |
+
payload = {
|
49 |
+
"data": [
|
50 |
+
f"data:image/{img_format.lower()};base64,{img_base64}",
|
51 |
+
age_group,
|
52 |
+
theme
|
53 |
+
]
|
54 |
+
}
|
55 |
+
|
56 |
+
print(f"Sending request to {url}...")
|
57 |
+
print(f"Age group: {age_group}")
|
58 |
+
print(f"Theme: {theme}")
|
59 |
+
|
60 |
+
try:
|
61 |
+
response = requests.post(url, json=payload)
|
62 |
+
print(f"Status code: {response.status_code}")
|
63 |
+
|
64 |
+
if response.status_code == 200:
|
65 |
+
result = response.json()
|
66 |
+
story = result.get('data', [''])[0]
|
67 |
+
|
68 |
+
print("\nGenerated Story:")
|
69 |
+
print("=" * 50)
|
70 |
+
print(story)
|
71 |
+
print("=" * 50)
|
72 |
+
return True
|
73 |
+
else:
|
74 |
+
print(f"Error: {response.text}")
|
75 |
+
return False
|
76 |
+
except Exception as e:
|
77 |
+
print(f"Error during request: {str(e)}")
|
78 |
+
return False
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
parser = argparse.ArgumentParser(description='Test the story generation server')
|
82 |
+
parser.add_argument('--url', default='http://localhost:7860', help='Base URL of the Gradio server')
|
83 |
+
parser.add_argument('--image', required=True, help='Path to the test image')
|
84 |
+
parser.add_argument('--age', default='6-8', choices=['6-8', '9-12'], help='Age group target')
|
85 |
+
parser.add_argument('--theme', default='adventure',
|
86 |
+
choices=['none', 'adventure', 'fantasy', 'animals', 'friendship', 'science'],
|
87 |
+
help='Story theme')
|
88 |
+
|
89 |
+
args = parser.parse_args()
|
90 |
+
|
91 |
+
print(f"Testing Gradio server at {args.url}")
|
92 |
+
|
93 |
+
if test_health_endpoint(args.url):
|
94 |
+
print("\nServer is running!")
|
95 |
+
print("\nTesting story generation...")
|
96 |
+
if test_story_generation(args.url, args.image, args.age, args.theme):
|
97 |
+
print("\nStory generation test passed!")
|
98 |
+
else:
|
99 |
+
print("\nStory generation test failed!")
|
100 |
+
else:
|
101 |
+
print("\nServer health check failed, server may not be running correctly.")
|