walker11 commited on
Commit
1d96975
·
verified ·
1 Parent(s): 9469c91

Upload 8 files

Browse files
Files changed (8) hide show
  1. README.md +115 -13
  2. app.py +80 -0
  3. evaluate_model.py +151 -0
  4. huggingface-metadata.json +12 -0
  5. requirements.txt +8 -0
  6. run_server.py +51 -0
  7. story_generator.py +142 -0
  8. test_server.py +101 -0
README.md CHANGED
@@ -1,13 +1,115 @@
1
- ---
2
- title: RawiKids
3
- emoji: 💻
4
- colorFrom: pink
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.34.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Rawi Kids Vision-Language Model
2
+
3
+ A vision-language model that generates engaging short stories for children (ages 6-12) based on images. This project is designed to be integrated with the Rawi Kids Flutter application and uses the DeepSeek Vision API.
4
+
5
+ ## Features
6
+
7
+ - Generate age-appropriate stories from images
8
+ - Support for different age groups (6-8 and 9-12 years)
9
+ - Optional themes to influence story generation (adventure, fantasy, animals, etc.)
10
+ - Gradio web interface for easy testing
11
+ - Integration with Flutter app
12
+ - Uses DeepSeek Vision API (no local model needed)
13
+
14
+ ## Demo
15
+
16
+ This model can be tested using the Gradio web interface included in the project.
17
+
18
+ ## Setup and Installation
19
+
20
+ ### Prerequisites
21
+
22
+ - Python 3.8 or higher
23
+ - pip (Python package manager)
24
+ - Virtual environment (recommended)
25
+ - DeepSeek API Key
26
+
27
+ ### Getting a DeepSeek API Key
28
+
29
+ 1. Visit the [DeepSeek website](https://www.deepseek.com/) and sign up for an account
30
+ 2. Navigate to your API settings page to obtain an API key
31
+ 3. Copy the API key for use in the next steps
32
+
33
+ ### Installation
34
+
35
+ 1. Clone this repository
36
+ ```
37
+ git clone <repository-url>
38
+ cd rawi-kids-vlm
39
+ ```
40
+
41
+ 2. Create and activate a virtual environment
42
+ ```
43
+ python -m venv venv
44
+ # On Windows
45
+ venv\Scripts\activate
46
+ # On macOS/Linux
47
+ source venv/bin/activate
48
+ ```
49
+
50
+ 3. Install the required packages
51
+ ```
52
+ pip install -r requirements.txt
53
+ ```
54
+
55
+ 4. Create a `.env` file and add your DeepSeek API key
56
+ ```
57
+ echo "DEEPSEEK_API_KEY=your_api_key_here" > .env
58
+ ```
59
+
60
+ 5. Run the Gradio app
61
+ ```
62
+ python app.py
63
+ ```
64
+
65
+ The interface will be available at http://localhost:7860
66
+
67
+ ## Using the Interface
68
+
69
+ 1. Upload an image using the file uploader
70
+ 2. Select the target age group (6-8 or 9-12 years)
71
+ 3. Choose a story theme (optional)
72
+ 4. Click "Generate Story"
73
+ 5. The model will analyze the image and generate an age-appropriate story
74
+
75
+ ## Flutter Integration
76
+
77
+ See the `test_server.py` file for examples of how to integrate with your Flutter app. You'll need to implement an API client in your Flutter app that sends images to this service and receives the generated stories.
78
+
79
+ ## Testing
80
+
81
+ You can test the model using the provided test script:
82
+
83
+ ```
84
+ python test_server.py --url http://localhost:7860 --image path/to/test_image.jpg
85
+ ```
86
+
87
+ ## Evaluation
88
+
89
+ For more detailed evaluation of the model's performance, use the evaluation script:
90
+
91
+ ```
92
+ python evaluate_model.py --images test_images --output evaluation_results.json
93
+ ```
94
+
95
+ ## Deploying to Hugging Face Spaces
96
+
97
+ This project is designed to work with Hugging Face Spaces, which provides free hosting for machine learning demos.
98
+
99
+ 1. Create a new Space on Hugging Face
100
+ 2. Select "Gradio" as the SDK
101
+ 3. Push this repository to the Space
102
+ 4. Add your DeepSeek API key as a secret in the Space configuration
103
+ 5. The app will automatically deploy and be available at your Space URL
104
+
105
+ ## Important Note on API Usage
106
+
107
+ The DeepSeek API is a commercial service and may have usage limits or costs associated with it. Make sure to check their pricing and terms of service to understand any potential costs for your usage level.
108
+
109
+ ## License
110
+
111
+ [Add your license information here]
112
+
113
+ ## Contact
114
+
115
+ [Add your contact information here]
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from story_generator import StoryGenerator
3
+ import tempfile
4
+ import os
5
+
6
+ # Initialize the story generator
7
+ story_generator = StoryGenerator()
8
+
9
+ # Define the available themes
10
+ THEMES = ["None", "Adventure", "Fantasy", "Animals", "Friendship", "Science"]
11
+
12
+ def generate_story(image, age_group, theme):
13
+ """
14
+ Generate a story from an image using the story generator
15
+
16
+ Args:
17
+ image: The uploaded image
18
+ age_group: The target age group
19
+ theme: The story theme
20
+
21
+ Returns:
22
+ str: The generated story
23
+ """
24
+ # Save the image to a temporary file
25
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp:
26
+ image.save(temp.name)
27
+ temp_filename = temp.name
28
+
29
+ try:
30
+ # Process the theme (convert "None" to None)
31
+ processed_theme = None if theme == "None" else theme.lower()
32
+
33
+ # Open the image file and generate the story
34
+ with open(temp_filename, 'rb') as img_file:
35
+ story = story_generator.generate(img_file, age_group, processed_theme)
36
+
37
+ return story
38
+ except Exception as e:
39
+ return f"Error generating story: {str(e)}"
40
+ finally:
41
+ # Clean up the temporary file
42
+ if os.path.exists(temp_filename):
43
+ os.unlink(temp_filename)
44
+
45
+ # Create the Gradio interface
46
+ with gr.Blocks(title="Rawi Kids Story Generator") as demo:
47
+ gr.Markdown("# Rawi Kids Story Generator")
48
+ gr.Markdown("Upload an image and get a story for kids!")
49
+
50
+ with gr.Row():
51
+ with gr.Column(scale=1):
52
+ # Input components
53
+ image_input = gr.Image(type="pil", label="Upload Image")
54
+ age_group = gr.Radio(choices=["6-8", "9-12"], value="6-8", label="Age Group (years)")
55
+ theme = gr.Dropdown(choices=THEMES, value="None", label="Story Theme")
56
+ submit_btn = gr.Button("Generate Story", variant="primary")
57
+
58
+ with gr.Column(scale=1):
59
+ # Output component
60
+ story_output = gr.Textbox(label="Generated Story", lines=10)
61
+
62
+ # Set up the button click event
63
+ submit_btn.click(
64
+ fn=generate_story,
65
+ inputs=[image_input, age_group, theme],
66
+ outputs=story_output
67
+ )
68
+
69
+ gr.Markdown("""
70
+ ### How it works
71
+
72
+ 1. Upload a picture or take a photo
73
+ 2. Select the age group (6-8 or 9-12 years)
74
+ 3. Choose a theme for the story (optional)
75
+ 4. Click "Generate Story"
76
+ 5. The AI will analyze the image and create a story for kids!
77
+ """)
78
+
79
+ # For Hugging Face Spaces
80
+ demo.launch(server_name="0.0.0.0", server_port=7860)
evaluate_model.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import os
3
+ import argparse
4
+ from PIL import Image
5
+ import glob
6
+ import json
7
+ import time
8
+ from story_generator import StoryGenerator
9
+ from dotenv import load_dotenv
10
+ import sys
11
+
12
+ # Load environment variables for API key
13
+ load_dotenv()
14
+
15
+ class ModelEvaluator:
16
+ def __init__(self, images_dir, output_file):
17
+ """Initialize the evaluator with paths to images and output file"""
18
+ self.images_dir = images_dir
19
+ self.output_file = output_file
20
+
21
+ # Check for API key
22
+ if not os.getenv("DEEPSEEK_API_KEY"):
23
+ print("ERROR: DEEPSEEK_API_KEY environment variable not found.")
24
+ print("Please set your DeepSeek API key using:")
25
+ print(" - Create a .env file with DEEPSEEK_API_KEY=your_key_here")
26
+ print(" - Or set the environment variable directly")
27
+ sys.exit(1)
28
+
29
+ # Initialize story generator
30
+ self.generator = StoryGenerator()
31
+
32
+ # Create output directory if it doesn't exist
33
+ os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
34
+
35
+ def evaluate_all(self, limit=None):
36
+ """Evaluate the model on all images in the directory"""
37
+ image_files = glob.glob(os.path.join(self.images_dir, "*.jpg")) + \
38
+ glob.glob(os.path.join(self.images_dir, "*.jpeg")) + \
39
+ glob.glob(os.path.join(self.images_dir, "*.png"))
40
+
41
+ # Limit the number of images if specified (to control API usage)
42
+ if limit and limit > 0:
43
+ image_files = image_files[:limit]
44
+
45
+ print(f"Found {len(image_files)} images for evaluation")
46
+ print(f"NOTE: Using DeepSeek API - API call charges may apply")
47
+
48
+ results = []
49
+
50
+ for idx, img_path in enumerate(image_files):
51
+ print(f"\nProcessing image {idx + 1}/{len(image_files)}: {os.path.basename(img_path)}")
52
+
53
+ # Test with different age groups and themes (use fewer combinations to limit API calls)
54
+ for age_group in ["6-8", "9-12"]:
55
+ # Limit theme testing to save on API calls
56
+ for theme in [None, "adventure"]:
57
+ theme_str = theme if theme else "none"
58
+ print(f" Generating story for age group: {age_group}, theme: {theme_str}")
59
+
60
+ try:
61
+ start_time = time.time()
62
+ with open(img_path, 'rb') as img_file:
63
+ story = self.generator.generate(img_file, age_group, theme)
64
+ generation_time = time.time() - start_time
65
+
66
+ # Record the result
67
+ result = {
68
+ "image_path": img_path,
69
+ "age_group": age_group,
70
+ "theme": theme_str,
71
+ "generation_time_seconds": round(generation_time, 2),
72
+ "story_length_chars": len(story),
73
+ "story_words": len(story.split()),
74
+ "story": story
75
+ }
76
+
77
+ results.append(result)
78
+
79
+ # Print summary
80
+ print(f" Time: {result['generation_time_seconds']:.2f}s, "
81
+ f"Words: {result['story_words']}")
82
+
83
+ except Exception as e:
84
+ print(f" Error generating story: {str(e)}")
85
+ results.append({
86
+ "image_path": img_path,
87
+ "age_group": age_group,
88
+ "theme": theme_str,
89
+ "error": str(e)
90
+ })
91
+
92
+ # Save all results to file
93
+ with open(self.output_file, 'w') as f:
94
+ json.dump(results, f, indent=2)
95
+
96
+ print(f"\nEvaluation complete. Results saved to {self.output_file}")
97
+ return results
98
+
99
+ def print_summary(self, results):
100
+ """Print a summary of the evaluation results"""
101
+ if not results:
102
+ print("No results to summarize")
103
+ return
104
+
105
+ successful_generations = [r for r in results if "error" not in r]
106
+ error_generations = [r for r in results if "error" in r]
107
+
108
+ print(f"\nEvaluation Summary:")
109
+ print(f" Total images processed: {len(set([r['image_path'] for r in results]))}")
110
+ print(f" Total story generations: {len(results)}")
111
+ print(f" Successful generations: {len(successful_generations)}")
112
+ print(f" Failed generations: {len(error_generations)}")
113
+
114
+ if successful_generations:
115
+ avg_time = sum([r["generation_time_seconds"] for r in successful_generations]) / len(successful_generations)
116
+ avg_words = sum([r["story_words"] for r in successful_generations]) / len(successful_generations)
117
+ print(f" Average generation time: {avg_time:.2f} seconds")
118
+ print(f" Average story length: {avg_words:.1f} words")
119
+
120
+ # Analysis by age group
121
+ age_groups = ["6-8", "9-12"]
122
+ for age_group in age_groups:
123
+ age_results = [r for r in successful_generations if r["age_group"] == age_group]
124
+ if age_results:
125
+ avg_words = sum([r["story_words"] for r in age_results]) / len(age_results)
126
+ print(f" Age group {age_group}: {len(age_results)} stories, avg {avg_words:.1f} words")
127
+
128
+ # Analysis by theme
129
+ themes = ["none", "adventure", "fantasy", "animals"]
130
+ for theme in themes:
131
+ theme_results = [r for r in successful_generations if r["theme"] == theme]
132
+ if theme_results:
133
+ avg_words = sum([r["story_words"] for r in theme_results]) / len(theme_results)
134
+ print(f" Theme {theme}: {len(theme_results)} stories, avg {avg_words:.1f} words")
135
+
136
+
137
+ def main():
138
+ parser = argparse.ArgumentParser(description='Evaluate the story generation model')
139
+ parser.add_argument('--images', default='test_images', help='Directory containing test images')
140
+ parser.add_argument('--output', default='evaluation_results.json', help='Output file for results')
141
+ parser.add_argument('--limit', type=int, default=2, help='Limit the number of images to process (to control API usage)')
142
+
143
+ args = parser.parse_args()
144
+
145
+ evaluator = ModelEvaluator(args.images, args.output)
146
+ results = evaluator.evaluate_all(limit=args.limit)
147
+ evaluator.print_summary(results)
148
+
149
+
150
+ if __name__ == '__main__':
151
+ main()
huggingface-metadata.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "title": "Rawi Kids Story Generator",
3
+ "emoji": "📚",
4
+ "colorFrom": "blue",
5
+ "colorTo": "purple",
6
+ "sdk": "gradio",
7
+ "sdk_version": "3.50.2",
8
+ "python_version": "3.10",
9
+ "app_file": "app.py",
10
+ "pinned": false,
11
+ "license": "mit"
12
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ flask==2.0.1
2
+ pillow==9.5.0
3
+ python-dotenv==1.0.0
4
+ flask-cors==3.0.10
5
+ gunicorn==20.1.0
6
+ numpy==1.24.3
7
+ requests==2.31.0
8
+ gradio==3.50.2
run_server.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import os
3
+ import sys
4
+ import argparse
5
+ import subprocess
6
+ from dotenv import load_dotenv
7
+
8
+ def main():
9
+ parser = argparse.ArgumentParser(description='Run the Vision Language Model server')
10
+ parser.add_argument('--port', type=int, help='Port to run the server on')
11
+ parser.add_argument('--debug', action='store_true', help='Run in debug mode')
12
+ parser.add_argument('--host', default='0.0.0.0', help='Host to run the server on')
13
+ parser.add_argument('--workers', type=int, default=1, help='Number of Gunicorn workers')
14
+ parser.add_argument('--use-gunicorn', action='store_true', help='Use Gunicorn for production')
15
+
16
+ args = parser.parse_args()
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
+
21
+ # Set environment variables from command line arguments
22
+ if args.port:
23
+ os.environ['PORT'] = str(args.port)
24
+
25
+ if args.debug:
26
+ os.environ['DEBUG'] = 'True'
27
+
28
+ port = int(os.environ.get('PORT', 5000))
29
+ debug = os.environ.get('DEBUG', 'False').lower() == 'true'
30
+
31
+ print(f"Starting server on {args.host}:{port}")
32
+ print(f"Debug mode: {debug}")
33
+
34
+ if args.use_gunicorn:
35
+ # Use Gunicorn for production
36
+ cmd = [
37
+ 'gunicorn',
38
+ '--bind', f"{args.host}:{port}",
39
+ '--workers', str(args.workers),
40
+ 'app:app'
41
+ ]
42
+ print(f"Running with gunicorn: {' '.join(cmd)}")
43
+ subprocess.call(cmd)
44
+ else:
45
+ # Use Flask's built-in server
46
+ from app import app
47
+ app.run(host=args.host, port=port, debug=debug)
48
+
49
+
50
+ if __name__ == '__main__':
51
+ main()
story_generator.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import base64
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ import json
7
+ from dotenv import load_dotenv
8
+
9
+ # Load environment variables for API keys
10
+ load_dotenv()
11
+
12
+ class StoryGenerator:
13
+ def __init__(self):
14
+ """Initialize the story generator with DeepSeek API configuration"""
15
+ self.api_key = os.getenv("DEEPSEEK_API_KEY")
16
+ if not self.api_key:
17
+ print("Warning: DEEPSEEK_API_KEY not found in environment variables. Please set it.")
18
+
19
+ self.api_url = "https://api.deepseek.com/v1/chat/completions"
20
+
21
+ # Story templates for different age groups
22
+ self.templates = {
23
+ "6-8": "Write a simple and fun short story for a 6-8 year old child about this image: ",
24
+ "9-12": "Write an engaging short story with a simple moral for a 9-12 year old about this image: "
25
+ }
26
+
27
+ # Themes and associated vocabulary to enhance stories
28
+ self.themes = {
29
+ "adventure": ["journey", "discover", "explore", "treasure", "map"],
30
+ "fantasy": ["magic", "dragon", "wizard", "fairy", "kingdom"],
31
+ "animals": ["forest", "pets", "wildlife", "jungle", "farm"],
32
+ "friendship": ["friends", "sharing", "helping", "together", "team"],
33
+ "science": ["experiment", "invention", "discovery", "robot", "space"]
34
+ }
35
+
36
+ def generate(self, image_file, age_group="6-12", theme=None):
37
+ """
38
+ Generate a story based on the input image using DeepSeek API
39
+
40
+ Args:
41
+ image_file: The uploaded image file
42
+ age_group: Age group target ("6-8" or "9-12")
43
+ theme: Optional theme to influence the story
44
+
45
+ Returns:
46
+ str: A generated story suitable for the specified age group
47
+ """
48
+ try:
49
+ # Process the image
50
+ image = Image.open(image_file).convert('RGB')
51
+
52
+ # Resize image if too large
53
+ max_size = 1024
54
+ if max(image.size) > max_size:
55
+ ratio = max_size / max(image.size)
56
+ new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
57
+ image = image.resize(new_size, Image.LANCZOS)
58
+
59
+ # Convert image to base64
60
+ buffered = BytesIO()
61
+ image.save(buffered, format="JPEG", quality=85)
62
+ img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
63
+
64
+ # Determine the template based on age group
65
+ template = self.templates.get(age_group, self.templates["6-8"])
66
+
67
+ # Enhance the prompt with theme if provided
68
+ if theme and theme in self.themes:
69
+ theme_words = ", ".join(self.themes[theme][:3]) # Use first 3 theme words
70
+ prompt = f"{template} Please include elements of {theme} like {theme_words}."
71
+ else:
72
+ prompt = template
73
+
74
+ # Create the API payload
75
+ payload = {
76
+ "model": "deepseek-vision",
77
+ "messages": [
78
+ {
79
+ "role": "user",
80
+ "content": [
81
+ {
82
+ "type": "text",
83
+ "text": prompt
84
+ },
85
+ {
86
+ "type": "image_url",
87
+ "image_url": {
88
+ "url": f"data:image/jpeg;base64,{img_base64}"
89
+ }
90
+ }
91
+ ]
92
+ }
93
+ ],
94
+ "max_tokens": 1000,
95
+ "temperature": 0.7
96
+ }
97
+
98
+ # Set the headers with authorization
99
+ headers = {
100
+ "Content-Type": "application/json",
101
+ "Authorization": f"Bearer {self.api_key}"
102
+ }
103
+
104
+ # Make the API request
105
+ response = requests.post(self.api_url, headers=headers, json=payload)
106
+ response.raise_for_status()
107
+
108
+ # Parse the API response
109
+ result = response.json()
110
+ story = result.get("choices", [{}])[0].get("message", {}).get("content", "")
111
+
112
+ if not story:
113
+ raise ValueError("No story was generated from the API")
114
+
115
+ # Format the story
116
+ story = self._format_story(story, age_group)
117
+
118
+ return story
119
+
120
+ except Exception as e:
121
+ print(f"Error generating story: {str(e)}")
122
+ raise e
123
+
124
+ def _format_story(self, story, age_group):
125
+ """Format the story based on age group"""
126
+ # Add paragraph breaks every 2-3 sentences
127
+ sentences = story.split('.')
128
+ formatted_text = ""
129
+
130
+ for i, sentence in enumerate(sentences):
131
+ if sentence.strip(): # Skip empty sentences
132
+ formatted_text += sentence.strip() + "."
133
+ if i % 3 == 2: # Add paragraph break every 3 sentences
134
+ formatted_text += "\n\n"
135
+
136
+ # For younger kids, keep it shorter by taking just the first few paragraphs
137
+ if age_group == "6-8":
138
+ paragraphs = formatted_text.split("\n\n")
139
+ if len(paragraphs) > 3:
140
+ formatted_text = "\n\n".join(paragraphs[:3])
141
+
142
+ return formatted_text
test_server.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+ import argparse
4
+ from PIL import Image
5
+ import json
6
+ import base64
7
+ from io import BytesIO
8
+
9
+ def test_health_endpoint(base_url):
10
+ """Test if the Gradio server is running"""
11
+ try:
12
+ response = requests.get(f"{base_url}")
13
+ print(f"Server health check: {response.status_code}")
14
+ assert response.status_code == 200
15
+ return response.status_code == 200
16
+ except Exception as e:
17
+ print(f"Error connecting to server: {str(e)}")
18
+ return False
19
+
20
+ def test_story_generation(base_url, image_path, age_group="6-8", theme="adventure"):
21
+ """Test story generation with an image using the Gradio API"""
22
+ if not os.path.exists(image_path):
23
+ print(f"Error: Image file not found at {image_path}")
24
+ return False
25
+
26
+ # Ensure the image can be opened
27
+ try:
28
+ img = Image.open(image_path)
29
+ img_format = img.format if img.format else "JPEG"
30
+ img_buffer = BytesIO()
31
+ img.save(img_buffer, format=img_format)
32
+ img_bytes = img_buffer.getvalue()
33
+ img_base64 = base64.b64encode(img_bytes).decode('utf-8')
34
+ except Exception as e:
35
+ print(f"Error processing image: {str(e)}")
36
+ return False
37
+
38
+ # Prepare the API request for Gradio
39
+ url = f"{base_url}/api/predict"
40
+
41
+ # Convert theme to the correct format for Gradio
42
+ if theme.lower() == "none":
43
+ theme = "None"
44
+ else:
45
+ theme = theme.capitalize()
46
+
47
+ # Build the payload
48
+ payload = {
49
+ "data": [
50
+ f"data:image/{img_format.lower()};base64,{img_base64}",
51
+ age_group,
52
+ theme
53
+ ]
54
+ }
55
+
56
+ print(f"Sending request to {url}...")
57
+ print(f"Age group: {age_group}")
58
+ print(f"Theme: {theme}")
59
+
60
+ try:
61
+ response = requests.post(url, json=payload)
62
+ print(f"Status code: {response.status_code}")
63
+
64
+ if response.status_code == 200:
65
+ result = response.json()
66
+ story = result.get('data', [''])[0]
67
+
68
+ print("\nGenerated Story:")
69
+ print("=" * 50)
70
+ print(story)
71
+ print("=" * 50)
72
+ return True
73
+ else:
74
+ print(f"Error: {response.text}")
75
+ return False
76
+ except Exception as e:
77
+ print(f"Error during request: {str(e)}")
78
+ return False
79
+
80
+ if __name__ == "__main__":
81
+ parser = argparse.ArgumentParser(description='Test the story generation server')
82
+ parser.add_argument('--url', default='http://localhost:7860', help='Base URL of the Gradio server')
83
+ parser.add_argument('--image', required=True, help='Path to the test image')
84
+ parser.add_argument('--age', default='6-8', choices=['6-8', '9-12'], help='Age group target')
85
+ parser.add_argument('--theme', default='adventure',
86
+ choices=['none', 'adventure', 'fantasy', 'animals', 'friendship', 'science'],
87
+ help='Story theme')
88
+
89
+ args = parser.parse_args()
90
+
91
+ print(f"Testing Gradio server at {args.url}")
92
+
93
+ if test_health_endpoint(args.url):
94
+ print("\nServer is running!")
95
+ print("\nTesting story generation...")
96
+ if test_story_generation(args.url, args.image, args.age, args.theme):
97
+ print("\nStory generation test passed!")
98
+ else:
99
+ print("\nStory generation test failed!")
100
+ else:
101
+ print("\nServer health check failed, server may not be running correctly.")