# RawiKids / app.py
# Author: walker11
# Last commit: 017a7bc "Update app.py" (verified)
import os
import sys
import base64
import random
import requests
import tempfile
import gradio as gr
import numpy as np
from io import BytesIO
from PIL import Image
from dotenv import load_dotenv
from gradio_client import Client
# Populate os.environ from a local .env file (if any) before keys are read.
load_dotenv()

# ============================================================
# CORE CONFIGURATION
# ============================================================

# Every API key the app needs, mapped to the service it unlocks.
required_keys = {
    "OPENROUTER_API_KEY": "OpenRouter for image recognition",
    "DEEPSEEK_API_KEY": "DeepSeek for story generation",
    "HF_ACCESS_TOKEN": "HuggingFace for text-to-speech",
}

# Human-readable "<KEY> (<purpose>)" labels for every key that is unset or empty.
missing_keys = [
    f"{env_name} ({service})"
    for env_name, service in required_keys.items()
    if not os.getenv(env_name)
]

# Fail fast at import time: nothing in the app works without all three services.
if missing_keys:
    print("ERROR: The following required API keys were not found:")
    for label in missing_keys:
        print(f" - {label}")
    print("\nPlease set these environment variables in a .env file or directly.")
    sys.exit(1)
# Story themes; StoryGenerator picks one at random for every story.
# Each entry is used as a key into StoryGenerator.themes, so both lists must
# stay in sync ("mystery" has vocabulary there but is deliberately absent here).
AVAILABLE_THEMES = [
    "adventure",
    "fantasy",
    "animals",
    "friendship",
    "science",
    "humor",
    "history",
    "arts",
    "space",
    "sports",
    "heroes",
    "underwater",
    "nature",
    "technology",
    "family",
    "school",
    "seasons",
    "travel",
    "food",
]

# Voice / emotion options of the TTS service. Not referenced elsewhere in this
# file: TTSClient always uses the "shimmer" voice with STORYTELLING_STYLE.
VOICES = ["nova", "alloy", "echo", "fable", "onyx", "shimmer"]
EMOTIONS = ["neutral", "happy", "sad", "anger", "story!!"]
# ============================================================
# TTS CLIENT IMPLEMENTATION
# ============================================================
# Narration style prompt for the TTS service. TTSClient always uses the
# "shimmer" voice and sends this text as the `emotion` argument of the
# Space's "/text_to_speech_app" endpoint, giving a warm storytelling tone for kids.
STORYTELLING_STYLE = "Affect: A gentle, curious narrator with a British accent, guiding a magical, child-friendly adventure through a fairy tale world. Tone: Magical, warm, and inviting, creating a sense of wonder and excitement for young listeners. Pacing: Steady and measured, with slight pauses to emphasize magical moments and maintain the storytelling flow. Emotion: Wonder, curiosity, and a sense of adventure, with a lighthearted and positive vibe throughout. Pronunciation: Clear and precise, with an emphasis on storytelling, ensuring the words are easy to follow and enchanting to listen to."
class TTSClient:
    """Client for the private Hugging Face text-to-speech Space.

    The underlying ``gradio_client.Client`` is created lazily on the first
    ``generate_audio`` call. Every request uses the fixed "shimmer" voice and
    the module-level ``STORYTELLING_STYLE`` prompt.

    Attributes:
        hf_token: Hugging Face access token read from ``HF_ACCESS_TOKEN``.
        password: Password sent with every request to the private Space.
        space_id: Hugging Face Space that hosts the TTS app.
        voice: Voice name sent to the Space.
        storytelling_style: Style prompt sent as the ``emotion`` argument.
        client: Lazily created gradio client (``None`` until first use).
        error_message: Most recent configuration/initialization error, or ``None``.
    """

    def __init__(self):
        """Read configuration from the environment; does not contact the Space.

        Every attribute is assigned before the token check so that an instance
        created without a token is still consistent and reports a clear error
        instead of raising ``AttributeError`` later.
        """
        self.hf_token = os.getenv("HF_ACCESS_TOKEN")
        # NOTE(review): a password committed to source is readable by anyone with
        # repository access. TTS_SPACE_PASSWORD lets deployments override it; the
        # literal fallback only preserves current behaviour and should be removed.
        self.password = os.getenv("TTS_SPACE_PASSWORD", "YSF9580")
        self.space_id = "KindSynapse/Youssef-Ahmed-Private-Text-To-Speech-Unlimited"
        # Fixed voice and storytelling style for kids
        self.voice = "shimmer"
        self.storytelling_style = STORYTELLING_STYLE
        # Client is initialized only when needed
        self.client = None
        self.error_message = None
        if not self.hf_token:
            self.error_message = "HF_ACCESS_TOKEN not found in environment variables"
            print(self.error_message)

    def _initialize_client(self):
        """Create the gradio client on first use.

        Returns:
            tuple[bool, str]: ``(True, status)`` when a client is available,
            ``(False, reason)`` when it could not be created.
        """
        if self.client is not None:
            return True, "Client already initialized"
        if not self.hf_token:
            # The Space is private; don't attempt an unauthenticated connection.
            self.error_message = "HF_ACCESS_TOKEN not found in environment variables"
            return False, self.error_message
        try:
            print(f"Initializing TTS client for space: {self.space_id}")
            self.client = Client(
                src=self.space_id,
                hf_token=self.hf_token
            )
        except Exception as e:
            error_msg = str(e)
            self.error_message = f"Failed to initialize TTS client: {error_msg}"
            print(self.error_message)
            if "SecurityError" in error_msg or "cross-origin" in error_msg:
                return False, "Cross-origin error detected"
            return False, self.error_message
        # A later successful attempt supersedes any earlier failure.
        self.error_message = None
        return True, "Client initialized successfully"

    def generate_audio(self, text, use_random_seed=True, seed=12345):
        """Convert text to speech with the fixed voice and storytelling style.

        Args:
            text (str): The text to narrate.
            use_random_seed (bool): Forwarded to the Space as ``use_random_seed``.
            seed (int): Forwarded as ``specific_seed``; presumably only used by the
                Space when ``use_random_seed`` is False (not visible here).

        Returns:
            str | tuple: Path of the generated audio file on success, otherwise
            ``(None, error_message)``.
        """
        client_ok, client_msg = self._initialize_client()
        if not client_ok:
            return None, client_msg
        if self.client is None:
            return None, "TTS client could not be initialized"

        try:
            result = self.client.predict(
                password=self.password,
                prompt=text,
                voice=self.voice,
                emotion=self.storytelling_style,
                use_random_seed=use_random_seed,
                specific_seed=seed,
                api_name="/text_to_speech_app"
            )
        except Exception as e:
            return None, f"Error generating audio: {str(e)}"

        # Accept either a bare file path or a (path, status) tuple from the Space.
        if isinstance(result, tuple):
            if result and isinstance(result[0], str):
                if os.path.exists(result[0]):
                    return result[0]
                return None, f"Audio file not found at {result[0]}"
            if len(result) >= 2 and result[0] is None:
                # No file: the second element carries the Space's error message.
                return None, result[1]
            return None, "Invalid response from TTS service"
        if isinstance(result, str):
            if os.path.exists(result):
                return result
            return None, f"Audio file not found at {result}"
        return None, f"Unexpected result type: {type(result)}"
# ============================================================
# STORY GENERATOR IMPLEMENTATION
# ============================================================
class StoryGenerator:
    """Turn an uploaded image into a short Arabic story for children aged 6-10.

    Pipeline: an OpenRouter vision model describes the image, DeepSeek writes
    a story from that description using a random theme from
    ``AVAILABLE_THEMES``, and the story text is cleaned and split into
    paragraphs.
    """

    # Seconds to wait for each HTTP call. Without a timeout, ``requests`` can
    # block the Gradio worker indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 120
    # Longest image side (pixels) sent to the vision model; larger images are downscaled.
    MAX_IMAGE_SIDE = 2048

    def __init__(self):
        """Read API keys/site info from the environment and set up prompts."""
        # Get API keys from environment variables
        self.openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
        self.deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
        # API endpoints
        self.openrouter_api_url = "https://openrouter.ai/api/v1/chat/completions"
        self.deepseek_api_url = "https://api.deepseek.com/v1/chat/completions"
        # Site information for OpenRouter (sent as HTTP-Referer / X-Title headers)
        self.site_url = os.getenv("SITE_URL", "https://rawi-kids-stories.com")
        self.site_name = os.getenv("SITE_NAME", "Rawi Kids Stories")
        # OpenRouter model for vision capabilities
        self.vision_model = "openai/gpt-4.1"
        # Arabic story prompt for kids aged 6-10; "{}" receives the image description
        self.template = "اكتب قصة بسيطة وممتعة لطفل يتراوح عمره بين 6-10 سنوات عن: {}. حافظ على أن تكون القصة حوالي 300-400 كلمة بلغة بسيطة ورسالة إيجابية. اكتب القصة باللغة العربية الفصحى المناسبة للأطفال. تأكد من أن القصة تحترم الثقافة والأخلاق العربية والإسلامية. في نهاية القصة، اختم بعبارة واضحة تبدأ بـ 'ومن هذه القصة نتعلم أن...' توضح الدرس الإيجابي المستفاد من القصة."
        # Arabic vocabulary per theme, used to enrich the story prompt. Must
        # contain every AVAILABLE_THEMES entry ("mystery" exists here but is
        # never selected because AVAILABLE_THEMES does not list it).
        self.themes = {
            "adventure": ["مغامرة", "استكشاف", "اكتشاف", "كنز", "خريطة"],
            "fantasy": ["سحر", "تنين", "ساحر", "جنية", "مملكة"],
            "animals": ["غابة", "حيوانات أليفة", "حياة برية", "أدغال", "مزرعة"],
            "friendship": ["صداقة", "مشاركة", "مساعدة", "تعاون", "فريق"],
            "science": ["تجربة", "اختراع", "اكتشاف", "روبوت", "فضاء"],
            "humor": ["مرح", "ضحك", "نكتة", "مواقف مضحكة", "فكاهة"],
            "mystery": ["لغز", "سر", "تحقيق", "غموض", "حل الألغاز"],
            "history": ["تاريخ", "حضارة", "ملك", "قديم", "أثري"],
            "arts": ["موسيقى", "فن", "رسم", "رقص", "إبداع"],
            "space": ["كواكب", "نجوم", "فضاء", "مجرة", "رواد فضاء"],
            "sports": ["كرة", "سباق", "مباراة", "رياضة", "لعبة"],
            "heroes": ["بطل", "شجاعة", "إنقاذ", "قوة", "مساعدة"],
            "underwater": ["بحر", "محيط", "أسماك", "غواص", "مرجان"],
            "nature": ["طبيعة", "حدائق", "جبال", "نباتات", "زهور"],
            "technology": ["اختراع", "آلة", "حاسوب", "روبوت", "مستقبل"],
            "family": ["عائلة", "أسرة", "والدين", "أطفال", "محبة"],
            "school": ["مدرسة", "معلم", "تلميذ", "فصل", "تعلم"],
            "seasons": ["فصول", "شتاء", "صيف", "خريف", "ربيع"],
            "travel": ["سفر", "رحلة", "بلدان", "مغامرة", "استكشاف"],
            "food": ["طعام", "وصفات", "مطبخ", "طهي", "حلويات"]
        }

    def get_image_description(self, image_bytes):
        """Describe an image in Arabic using the OpenRouter vision model.

        Args:
            image_bytes: Encoded image data; sent as an ``image/jpeg`` data URL,
                so it is expected to be JPEG (``generate`` always produces JPEG).

        Returns:
            str: The model's description.

        Raises:
            ValueError: If the API key is missing or the request/response fails.
        """
        if not self.openrouter_api_key:
            raise ValueError("OpenRouter API key not set")
        img_base64 = base64.b64encode(image_bytes).decode('utf-8')
        prompt = "صف هذه الصورة بالتفصيل. ركز على العناصر الرئيسية والأعمال والإعدادات وأي عناصر ملحوظة. قدم وصفًا واضحًا وشاملًا يمكن استخدامه كسياق لكتابة قصة للأطفال. الرجاء الرد باللغة العربية."
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openrouter_api_key}",
            "HTTP-Referer": self.site_url,
            "X-Title": self.site_name
        }
        payload = {
            "model": self.vision_model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{img_base64}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }
        try:
            response = requests.post(
                self.openrouter_api_url,
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()
            result = response.json()
            return result["choices"][0]["message"]["content"]
        except Exception as e:
            raise ValueError(f"Error getting image description: {str(e)}") from e

    def generate_story_with_deepseek(self, description):
        """Write a story from an image description using DeepSeek and a random theme.

        Args:
            description: Arabic description of the image.

        Returns:
            tuple[str, str]: ``(raw_story_text, theme)``.

        Raises:
            ValueError: If the API key is missing, the request fails, or no
                story text comes back.
        """
        if not self.deepseek_api_key:
            raise ValueError("DeepSeek API key not set")
        theme = random.choice(AVAILABLE_THEMES)
        theme_words = "، ".join(self.themes[theme][:3])  # first 3 theme words
        prompt = f"{self.template.format(description)} ضمّن عناصر {theme} مثل {theme_words}."
        payload = {
            "model": "deepseek-chat",
            "messages": [
                {
                    "role": "system",
                    "content": "أنت كاتب قصص للأطفال متخصص في كتابة القصص باللغة العربية الفصحى المناسبة للأطفال. مهمتك إنشاء قصص تعليمية وممتعة مستوحاة من الوصف المقدم. تأكد من احترام القيم والثقافة العربية والإسلامية في جميع القصص. يجب أن تنتهي كل قصة بدرس أخلاقي واضح يبدأ بعبارة 'ومن هذه القصة نتعلم أن...' لترسيخ القيم الإيجابية لدى الأطفال."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 1000,
            "temperature": 0.7
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.deepseek_api_key}"
        }
        try:
            response = requests.post(
                self.deepseek_api_url,
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()
            result = response.json()
            # An empty "choices" list is reported as "no story" rather than an IndexError.
            choices = result.get("choices") or [{}]
            story = choices[0].get("message", {}).get("content", "")
            if not story:
                raise ValueError("No story was generated from the DeepSeek API")
            return story, theme
        except Exception as e:
            raise ValueError(f"Error generating story with DeepSeek: {str(e)}") from e

    def generate(self, image_file):
        """Generate a story for kids aged 6-10 from an image.

        Args:
            image_file: Path or binary file object accepted by ``PIL.Image.open``.

        Returns:
            tuple[str, str]: ``(formatted_story, theme)``.

        Raises:
            ValueError: Wrapping any failure in image processing or either API call.
        """
        try:
            with Image.open(image_file) as source:
                image = source.convert('RGB')
            longest_side = max(image.size)
            if longest_side > self.MAX_IMAGE_SIDE:
                ratio = self.MAX_IMAGE_SIDE / longest_side
                new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
                image = image.resize(new_size, Image.LANCZOS)
            buffered = BytesIO()
            image.save(buffered, format="JPEG", quality=85)
            img_bytes = buffered.getvalue()
            # Step 1: describe the image; Step 2: write the story from that description
            description = self.get_image_description(img_bytes)
            story, theme = self.generate_story_with_deepseek(description)
            return self._format_story(story), theme
        except Exception as e:
            raise ValueError(f"Error generating story: {str(e)}") from e

    def _format_story(self, story):
        """Clean the story text and make sure it is split into paragraphs.

        ``_clean_special_characters`` drops blank lines, so a story that came
        back with paragraphs ends up with one paragraph per line; those lines
        are kept and re-separated by blank lines. A story that is a single
        block of text is instead split into paragraphs of three sentences
        (sentences are delimited by ".").

        Args:
            story: Raw story text from the model.

        Returns:
            str: Cleaned story whose paragraphs are separated by blank lines.
        """
        cleaned_story = self._clean_special_characters(story)
        lines = [line.strip() for line in cleaned_story.split('\n') if line.strip()]
        if len(lines) > 1:
            # The model already structured the story; keep its paragraphs.
            return "\n\n".join(lines)
        sentences = [part.strip() for part in cleaned_story.split('.') if part.strip()]
        paragraphs = [
            " ".join(f"{sentence}." for sentence in sentences[start:start + 3])
            for start in range(0, len(sentences), 3)
        ]
        return "\n\n".join(paragraphs)

    def _clean_special_characters(self, text):
        """Strip markdown symbols, collapse repeated spaces and drop empty/rule-only lines."""
        import re
        # Remove markdown formatting characters (* # _ ~)
        cleaned = re.sub(r'[*#_~]+', '', text)
        # Remove inline-code backticks
        cleaned = re.sub(r'[`]+', '', cleaned)
        # Collapse double spaces left behind by the removals
        cleaned = re.sub(r' +', ' ', cleaned)
        # Drop blank lines and lines made only of rule characters (e.g. "---")
        filtered_lines = [
            line for line in cleaned.split('\n')
            if line.strip() and not all(c in '=-+*#_~' for c in line.strip())
        ]
        return '\n'.join(filtered_lines)
# ============================================================
# APP FUNCTIONS
# ============================================================

# Created lazily by generate_audio() the first time narration is requested.
tts_client = None
# One generator instance shared by every request.
story_generator = StoryGenerator()
def generate_story(image):
    """Generate a formatted kids' story from an uploaded image.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.

    Returns:
        str: The formatted story, or a user-facing message: "Please upload an
        image first." when ``image`` is None, or one starting with
        "Error generating story:" on any failure.
    """
    if image is None:
        return "Please upload an image first."
    try:
        # Encode in memory rather than through a delete=False temp file: nothing
        # can leak on failure and there is no reopen-while-open issue on Windows.
        # JPEG cannot hold alpha/palette modes, so normalise to RGB first.
        buffer = BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")
        buffer.seek(0)
        story, _theme = story_generator.generate(buffer)
        # Put title and lesson in their own paragraphs
        return format_story_output(story)
    except Exception as e:
        return f"Error generating story: {str(e)}"
def format_story_output(story):
    """Lay out a story as title, body paragraphs and closing lesson.

    Markdown symbols (* # _ ~ `) are removed from every line. The first
    non-empty line becomes the title. The line containing
    "ومن هذه القصة نتعلم أن" becomes the lesson; if no line contains it, the
    last body paragraph is used when it mentions a lesson marker
    ("الدرس", "العبرة", "نتعلم").

    Args:
        story: Story text, one paragraph per line (blank lines allowed).

    Returns:
        str: Title, body paragraphs and lesson separated by single blank lines,
        with no leading or trailing blank lines.
    """
    import re

    def clean_text(text):
        # Remove markdown formatting characters and surrounding whitespace.
        return re.sub(r'[*#_~`]+', '', text).strip()

    title = ""
    lesson = ""
    story_content = []
    for para in story.split('\n'):
        para = clean_text(para)
        if not para:
            continue
        # The first non-empty paragraph is treated as the title.
        if not title:
            title = para
            continue
        # If several lines carry the lesson phrase, the last one is kept.
        if "ومن هذه القصة نتعلم أن" in para:
            lesson = para
            continue
        story_content.append(para)

    # No explicit lesson phrase: fall back to a lesson-like last paragraph.
    if not lesson and story_content:
        if any(marker in story_content[-1] for marker in ["الدرس", "العبرة", "نتعلم"]):
            lesson = story_content.pop()

    # Joining only non-empty sections avoids stray blank lines when the
    # title, body or lesson is missing.
    sections = [title, *story_content, lesson]
    return "\n\n".join(section for section in sections if section)
def generate_audio(story):
    """Narrate a story with the TTS client, retrying up to 3 times.

    Args:
        story: The story text to convert to audio (the stored story state).

    Returns:
        tuple: ``(audio_file_path, status_message)`` on success, otherwise
        ``(None, error_message)``.
    """
    if not story or story.startswith("Error") or story.startswith("Please upload"):
        return None, "Please generate a valid story first."

    global tts_client
    if tts_client is None:
        print("Creating new TTS client")
        tts_client = TTSClient()

    # Only a missing token is a permanent failure. TTSClient leaves
    # error_message set after a failed connection attempt, so checking that
    # field here would turn one transient failure into a permanent outage;
    # instead those failures are retried below (the client reconnects lazily).
    if not getattr(tts_client, "hf_token", None):
        reason = getattr(tts_client, "error_message", None) or "HF_ACCESS_TOKEN not set"
        return None, f"TTS client error: {reason}"

    max_attempts = 3
    last_error = None
    for attempt in range(1, max_attempts + 1):
        print(f"TTS generation attempt {attempt} of {max_attempts}")
        try:
            result = tts_client.generate_audio(text=story)
        except Exception as e:
            last_error = str(e)
            print(f"Exception on TTS attempt {attempt}: {last_error}")
            continue

        # The client returns either a path or a (path_or_None, message) tuple.
        if isinstance(result, tuple):
            if len(result) < 2:
                last_error = "Invalid response from TTS service"
                print(f"TTS error on attempt {attempt}: {last_error}")
                continue
            if result[0] is None:
                last_error = result[1]
                print(f"TTS error on attempt {attempt}: {last_error}")
                continue
            audio_path = result[0]
        else:
            audio_path = result

        if not isinstance(audio_path, str):
            last_error = f"Invalid audio path type: {type(audio_path)}"
            print(f"TTS error on attempt {attempt}: {last_error}")
            continue
        if not os.path.exists(audio_path):
            last_error = f"Audio file not found at {audio_path}"
            print(f"TTS error on attempt {attempt}: {last_error}")
            continue
        # Success: Gradio's Audio(type="filepath") plays this file.
        return audio_path, "Audio generated successfully!"

    return None, f"TTS failed after {max_attempts} attempts. Last error: {last_error}"
# ============================================================
# GRADIO INTERFACE
# ============================================================
# Single-page UI: upload an image, generate a story, then narrate it.
with gr.Blocks(title="Rawi Kids Story Generator") as demo:
    gr.Markdown("# Rawi Kids Story Generator")
    gr.Markdown("Upload an image and get a story for kids with audio narration!")
    # Story text captured right after the last "Generate Story" click; this
    # (not the textbox) is what the audio button narrates.
    story_state = gr.State("")
    # Create the interface
    with gr.Column():
        gr.Markdown("## Generate Story")
        image_input = gr.Image(type="pil", label="Upload Image")
        generate_story_btn = gr.Button("Generate Story", variant="primary")
        story_output = gr.Textbox(label="Generated Story", lines=10)
        gr.Markdown("## Generate Audio Narration")
        generate_audio_btn = gr.Button("Listen to Story", variant="secondary")
        audio_status = gr.Textbox(label="Audio Status", lines=1)
        # type="filepath" matches generate_audio returning a file path
        audio_output = gr.Audio(label="Story Narration", type="filepath")
    # Story button: run generate_story, then copy the textbox value into
    # story_state. story_state is written only here, so later manual edits to
    # the textbox are not picked up by the audio button.
    generate_story_btn.click(
        fn=generate_story,
        inputs=[image_input],
        outputs=story_output
    ).then(
        fn=lambda s: s,  # Store generated story in state
        inputs=[story_output],
        outputs=story_state
    )
    # Audio button: generate_audio returns (audio file path, status message),
    # filling the two outputs in that order.
    generate_audio_btn.click(
        fn=generate_audio,
        inputs=[story_state],
        outputs=[audio_output, audio_status]
    )
# Launch the app only when run as a script, not when imported
if __name__ == "__main__":
    demo.launch()