|
import os |
|
import sys |
|
import base64 |
|
import random |
|
import requests |
|
import tempfile |
|
import gradio as gr |
|
import numpy as np |
|
from io import BytesIO |
|
from PIL import Image |
|
from dotenv import load_dotenv |
|
from gradio_client import Client |
|
|
|
|
|
# Pull variables from a local .env file into the process environment before
# any key lookup happens.
load_dotenv()

# Environment variable name -> short description of what the key is used for.
required_keys = {
    "OPENROUTER_API_KEY": "OpenRouter for image recognition",
    "DEEPSEEK_API_KEY": "DeepSeek for story generation",
    "HF_ACCESS_TOKEN": "HuggingFace for text-to-speech",
}

# One "NAME (purpose)" entry for every key that is unset or empty.
missing_keys = [
    f"{name} ({usage})"
    for name, usage in required_keys.items()
    if not os.getenv(name)
]

# Abort start-up with a readable report instead of failing later mid-request.
if missing_keys:
    print("ERROR: The following required API keys were not found:")
    for entry in missing_keys:
        print(f" - {entry}")
    print("\nPlease set these environment variables in a .env file or directly.")
    sys.exit(1)
|
|
|
|
|
# Theme names a story may be built around; one is drawn at random per story.
AVAILABLE_THEMES = [
    "adventure",
    "fantasy",
    "animals",
    "friendship",
    "science",
    "humor",
    "history",
    "arts",
    "space",
    "sports",
    "heroes",
    "underwater",
    "nature",
    "technology",
    "family",
    "school",
    "seasons",
    "travel",
    "food",
]

# Voice and emotion names (not referenced elsewhere in this file).
VOICES = [
    "nova",
    "alloy",
    "echo",
    "fable",
    "onyx",
    "shimmer",
]
EMOTIONS = [
    "neutral",
    "happy",
    "sad",
    "anger",
    "story!!",
]

# Free-text narration instructions passed to the TTS service as its "emotion"
# argument. Written as adjacent literals purely for readability; the resulting
# value is a single string.
STORYTELLING_STYLE = (
    "Affect: A gentle, curious narrator with a British accent, guiding a "
    "magical, child-friendly adventure through a fairy tale world. "
    "Tone: Magical, warm, and inviting, creating a sense of wonder and "
    "excitement for young listeners. "
    "Pacing: Steady and measured, with slight pauses to emphasize magical "
    "moments and maintain the storytelling flow. "
    "Emotion: Wonder, curiosity, and a sense of adventure, with a "
    "lighthearted and positive vibe throughout. "
    "Pronunciation: Clear and precise, with an emphasis on storytelling, "
    "ensuring the words are easy to follow and enchanting to listen to."
)
|
|
|
class TTSClient:
    """Client for the private Hugging Face text-to-speech service.

    The underlying ``gradio_client.Client`` is created lazily on the first
    call to :meth:`generate_audio`; constructing this object only reads the
    environment.
    """

    def __init__(self):
        """Read credentials from the environment and set the fixed TTS settings.

        If ``HF_ACCESS_TOKEN`` is missing, ``error_message`` is set and printed.
        Every attribute is still defined in that case, so later method calls
        report the problem instead of raising ``AttributeError``.
        """
        # Defined first so they exist on every code path.
        self.client = None
        self.error_message = None

        # NOTE(review): the space password is a credential embedded in source.
        # It can be overridden through TTS_SPACE_PASSWORD; the literal is kept
        # only as a backward-compatible fallback and should be rotated/removed.
        self.password = os.getenv("TTS_SPACE_PASSWORD", "YSF9580")
        self.space_id = "KindSynapse/Youssef-Ahmed-Private-Text-To-Speech-Unlimited"

        # Fixed narration voice and style used for every request.
        self.voice = "shimmer"
        self.storytelling_style = STORYTELLING_STYLE

        self.hf_token = os.getenv("HF_ACCESS_TOKEN")
        if not self.hf_token:
            self.error_message = "HF_ACCESS_TOKEN not found in environment variables"
            print(self.error_message)

    def _initialize_client(self):
        """Create the gradio client on first use.

        Returns:
            tuple: ``(True, message)`` when a client is available, otherwise
            ``(False, error_message)``.
        """
        if self.client is not None:
            return True, "Client already initialized"

        # Without a token there is nothing to connect with; report it rather
        # than attempting a connection that cannot authenticate.
        if not self.hf_token:
            return False, (
                self.error_message
                or "HF_ACCESS_TOKEN not found in environment variables"
            )

        try:
            print(f"Initializing TTS client for space: {self.space_id}")
            self.client = Client(
                src=self.space_id,
                hf_token=self.hf_token
            )
        except Exception as e:
            error_msg = str(e)
            self.error_message = f"Failed to initialize TTS client: {error_msg}"
            print(self.error_message)
            if "SecurityError" in error_msg or "cross-origin" in error_msg:
                return False, "Cross-origin error detected"
            return False, self.error_message

        # A successful connection supersedes any error from an earlier attempt.
        self.error_message = None
        return True, "Client initialized successfully"

    @staticmethod
    def _interpret_result(result):
        """Turn the raw service response into a path or an error pair.

        Returns:
            str or tuple: the audio file path when the response names an
            existing file, otherwise ``(None, error_message)``.
        """
        if isinstance(result, tuple):
            if len(result) >= 1 and isinstance(result[0], str):
                file_path = result[0]
                if os.path.exists(file_path):
                    return file_path
            # The service signals failure as (None, message).
            if len(result) >= 2 and result[0] is None:
                return None, result[1]
            return None, "Invalid response from TTS service"

        if isinstance(result, str):
            if os.path.exists(result):
                return result
            return None, f"Audio file not found at {result}"

        return None, f"Unexpected result type: {type(result)}"

    def generate_audio(self, text, use_random_seed=True, seed=12345):
        """Generate audio from text with the fixed voice and narration style.

        Args:
            text (str): The text to convert to speech.
            use_random_seed (bool): Forwarded to the service as ``use_random_seed``.
            seed (int): Forwarded to the service as ``specific_seed``.

        Returns:
            str or tuple: Path to the generated audio file if successful, or
            ``(None, error_message)`` if there is an error. This method does
            not raise; every failure is reported through the return value.
        """
        client_ok, client_msg = self._initialize_client()
        if not client_ok:
            return None, client_msg

        try:
            result = self.client.predict(
                password=self.password,
                prompt=text,
                voice=self.voice,
                emotion=self.storytelling_style,
                use_random_seed=use_random_seed,
                specific_seed=seed,
                api_name="/text_to_speech_app"
            )
            return self._interpret_result(result)
        except Exception as e:
            return None, f"Error generating audio: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
class StoryGenerator:
    """Turn an image into an Arabic children's story.

    The image is described by a vision model through OpenRouter, and the
    description is then expanded into a story by the DeepSeek chat API.
    """

    # (connect, read) timeout in seconds for every outbound HTTP call, so a
    # stalled API cannot block a request forever.
    REQUEST_TIMEOUT = (10, 120)

    def __init__(self):
        """Initialize the story generator with APIs"""
        self.openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
        self.deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")

        self.openrouter_api_url = "https://openrouter.ai/api/v1/chat/completions"
        self.deepseek_api_url = "https://api.deepseek.com/v1/chat/completions"

        # Sent to OpenRouter as the HTTP-Referer / X-Title headers.
        self.site_url = os.getenv("SITE_URL", "https://rawi-kids-stories.com")
        self.site_name = os.getenv("SITE_NAME", "Rawi Kids Stories")

        self.vision_model = "openai/gpt-4.1"

        # Story prompt; the single "{}" placeholder receives the image description.
        self.template = "اكتب قصة بسيطة وممتعة لطفل يتراوح عمره بين 6-10 سنوات عن: {}. حافظ على أن تكون القصة حوالي 300-400 كلمة بلغة بسيطة ورسالة إيجابية. اكتب القصة باللغة العربية الفصحى المناسبة للأطفال. تأكد من أن القصة تحترم الثقافة والأخلاق العربية والإسلامية. في نهاية القصة، اختم بعبارة واضحة تبدأ بـ 'ومن هذه القصة نتعلم أن...' توضح الدرس الإيجابي المستفاد من القصة."

        # Theme name -> Arabic keywords; the first three are added to the prompt.
        self.themes = {
            "adventure": ["مغامرة", "استكشاف", "اكتشاف", "كنز", "خريطة"],
            "fantasy": ["سحر", "تنين", "ساحر", "جنية", "مملكة"],
            "animals": ["غابة", "حيوانات أليفة", "حياة برية", "أدغال", "مزرعة"],
            "friendship": ["صداقة", "مشاركة", "مساعدة", "تعاون", "فريق"],
            "science": ["تجربة", "اختراع", "اكتشاف", "روبوت", "فضاء"],
            "humor": ["مرح", "ضحك", "نكتة", "مواقف مضحكة", "فكاهة"],
            "mystery": ["لغز", "سر", "تحقيق", "غموض", "حل الألغاز"],
            "history": ["تاريخ", "حضارة", "ملك", "قديم", "أثري"],
            "arts": ["موسيقى", "فن", "رسم", "رقص", "إبداع"],
            "space": ["كواكب", "نجوم", "فضاء", "مجرة", "رواد فضاء"],
            "sports": ["كرة", "سباق", "مباراة", "رياضة", "لعبة"],
            "heroes": ["بطل", "شجاعة", "إنقاذ", "قوة", "مساعدة"],
            "underwater": ["بحر", "محيط", "أسماك", "غواص", "مرجان"],
            "nature": ["طبيعة", "حدائق", "جبال", "نباتات", "زهور"],
            "technology": ["اختراع", "آلة", "حاسوب", "روبوت", "مستقبل"],
            "family": ["عائلة", "أسرة", "والدين", "أطفال", "محبة"],
            "school": ["مدرسة", "معلم", "تلميذ", "فصل", "تعلم"],
            "seasons": ["فصول", "شتاء", "صيف", "خريف", "ربيع"],
            "travel": ["سفر", "رحلة", "بلدان", "مغامرة", "استكشاف"],
            "food": ["طعام", "وصفات", "مطبخ", "طهي", "حلويات"]
        }

    def get_image_description(self, image_bytes):
        """Get a description of the image from the OpenRouter vision model.

        Args:
            image_bytes (bytes): JPEG-encoded image data.

        Returns:
            str: The model's description of the image.

        Raises:
            ValueError: If the API key is missing, or if the request or the
                parsing of its response fails (the original error is chained).
        """
        if not self.openrouter_api_key:
            raise ValueError("OpenRouter API key not set")

        # The image travels inline as a base64 data URL.
        img_base64 = base64.b64encode(image_bytes).decode('utf-8')

        prompt = "صف هذه الصورة بالتفصيل. ركز على العناصر الرئيسية والأعمال والإعدادات وأي عناصر ملحوظة. قدم وصفًا واضحًا وشاملًا يمكن استخدامه كسياق لكتابة قصة للأطفال. الرجاء الرد باللغة العربية."

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.openrouter_api_key}",
            "HTTP-Referer": self.site_url,
            "X-Title": self.site_name
        }

        payload = {
            "model": self.vision_model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{img_base64}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }

        try:
            response = requests.post(
                self.openrouter_api_url,
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()
            result = response.json()
            return result["choices"][0]["message"]["content"]
        except Exception as e:
            raise ValueError(f"Error getting image description: {str(e)}") from e

    def generate_story_with_deepseek(self, description):
        """Generate a story from the image description with a random theme.

        Args:
            description (str): Text describing the image.

        Returns:
            tuple: ``(story, theme)`` - the generated text and the theme name
            that was drawn from ``AVAILABLE_THEMES``.

        Raises:
            ValueError: If the API key is missing, the request fails, or the
                response contains no story (the original error is chained).
        """
        if not self.deepseek_api_key:
            raise ValueError("DeepSeek API key not set")

        theme = random.choice(AVAILABLE_THEMES)
        theme_words = "، ".join(self.themes[theme][:3])

        prompt = f"{self.template.format(description)} ضمّن عناصر {theme} مثل {theme_words}."

        payload = {
            "model": "deepseek-chat",
            "messages": [
                {
                    "role": "system",
                    "content": "أنت كاتب قصص للأطفال متخصص في كتابة القصص باللغة العربية الفصحى المناسبة للأطفال. مهمتك إنشاء قصص تعليمية وممتعة مستوحاة من الوصف المقدم. تأكد من احترام القيم والثقافة العربية والإسلامية في جميع القصص. يجب أن تنتهي كل قصة بدرس أخلاقي واضح يبدأ بعبارة 'ومن هذه القصة نتعلم أن...' لترسيخ القيم الإيجابية لدى الأطفال."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 1000,
            "temperature": 0.7
        }

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.deepseek_api_key}"
        }

        try:
            response = requests.post(
                self.deepseek_api_url,
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()

            result = response.json()
            story = result.get("choices", [{}])[0].get("message", {}).get("content", "")

            if not story:
                raise ValueError("No story was generated from the DeepSeek API")

            return story, theme
        except Exception as e:
            raise ValueError(f"Error generating story with DeepSeek: {str(e)}") from e

    def generate(self, image_file):
        """Generate a story based on the input image.

        Args:
            image_file: Anything ``PIL.Image.open`` accepts (a path or a
                binary file-like object).

        Returns:
            tuple: ``(formatted_story, theme)``.

        Raises:
            ValueError: If any step (decoding, description, story) fails; the
                original error is chained.
        """
        try:
            image = Image.open(image_file).convert('RGB')

            # Downscale so the longest side is at most max_size pixels,
            # keeping the aspect ratio.
            max_size = 2048
            if max(image.size) > max_size:
                ratio = max_size / max(image.size)
                new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
                image = image.resize(new_size, Image.LANCZOS)

            buffered = BytesIO()
            image.save(buffered, format="JPEG", quality=85)
            img_bytes = buffered.getvalue()

            description = self.get_image_description(img_bytes)
            story, theme = self.generate_story_with_deepseek(description)
            formatted_story = self._format_story(story)

            return formatted_story, theme
        except Exception as e:
            raise ValueError(f"Error generating story: {str(e)}") from e

    def _format_story(self, story):
        """Clean the story and make sure it is split into paragraphs.

        A story that already contains blank-line paragraph breaks is returned
        as cleaned. Otherwise the text is split on "." and regrouped into
        paragraphs of three sentences, with sentences separated by a space.
        """
        cleaned_story = self._clean_special_characters(story)

        if "\n\n" in cleaned_story:
            return cleaned_story

        fragments = cleaned_story.split('.')
        last_index = len(fragments) - 1

        sentences = []
        for index, fragment in enumerate(fragments):
            text = fragment.strip()
            if not text:
                continue
            # Every fragment except the last was terminated by the "." that
            # split() consumed; only those get their period back.
            sentences.append(text + "." if index < last_index else text)

        paragraphs = [
            " ".join(sentences[start:start + 3])
            for start in range(0, len(sentences), 3)
        ]
        return "\n\n".join(paragraphs)

    def _clean_special_characters(self, text):
        """Remove markdown-style symbols and decoration-only lines.

        Runs of blank lines are collapsed to a single blank line rather than
        dropped, so paragraph breaks in the generated text survive cleaning.
        """
        import re

        # Strip emphasis/heading/code markers, then squeeze repeated spaces.
        cleaned = re.sub(r'[*#_~`]+', '', text)
        cleaned = re.sub(r' +', ' ', cleaned)

        kept_lines = []
        for line in cleaned.split('\n'):
            stripped = line.strip()
            if not stripped:
                # Keep at most one blank line, and never a leading one.
                if kept_lines and kept_lines[-1] != "":
                    kept_lines.append("")
                continue
            # Drop separator lines such as "-----" or "=====".
            if all(c in '=-+*#_~' for c in stripped):
                continue
            kept_lines.append(line)

        # No trailing paragraph separator.
        while kept_lines and kept_lines[-1] == "":
            kept_lines.pop()

        return '\n'.join(kept_lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared generator instance used by generate_story(); built once at import.
story_generator = StoryGenerator()
# TTS client is created on demand by generate_audio(), so it starts out unset.
tts_client = None
|
|
|
def generate_story(image):
    """Generate a story from an uploaded image.

    The image is encoded to JPEG in memory and handed to the shared story
    generator. Encoding happens inside the error boundary, so an image that
    cannot be encoded yields an error string instead of an unhandled
    exception, and no temporary file is left behind.

    Args:
        image: The uploaded PIL image, or None when nothing was uploaded.

    Returns:
        str: The formatted story, or a message starting with "Error" /
        "Please upload" when no story could be produced.
    """
    if image is None:
        return "Please upload an image first."

    try:
        # JPEG cannot store alpha or palette modes, so normalise to RGB first.
        buffer = BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")
        buffer.seek(0)

        # The theme is returned by the generator but not shown to the user.
        story, _theme = story_generator.generate(buffer)
        return format_story_output(story)
    except Exception as e:
        return f"Error generating story: {str(e)}"
|
|
|
def format_story_output(story):
    """Format the story as title, body paragraphs, then the lesson.

    Markdown-style symbols are removed and empty lines are dropped. The first
    remaining paragraph is the title. The last paragraph containing the
    explicit lesson phrase is moved to the end; if there is none, the final
    paragraph is used as the lesson when it contains a looser lesson keyword.
    No paragraph is ever discarded, and the result has no leading or trailing
    blank lines.

    Args:
        story: The raw story text.

    Returns:
        str: The sections joined by blank lines ("" for an empty story).
    """
    import re

    lesson_marker = "ومن هذه القصة نتعلم أن"
    loose_markers = ("الدرس", "العبرة", "نتعلم")

    # Cleaned, non-empty paragraphs in their original order.
    paragraphs = []
    for raw_line in story.split('\n'):
        text = re.sub(r'[*#_~`]+', '', raw_line).strip()
        if text:
            paragraphs.append(text)

    if not paragraphs:
        return ""

    title, body = paragraphs[0], paragraphs[1:]

    lesson = ""
    # Search from the end so that, if the phrase occurs more than once, only
    # the final occurrence moves and earlier ones stay in the body.
    for index in range(len(body) - 1, -1, -1):
        if lesson_marker in body[index]:
            lesson = body.pop(index)
            break
    else:
        if body and any(marker in body[-1] for marker in loose_markers):
            lesson = body.pop()

    sections = [title, *body]
    if lesson:
        sections.append(lesson)

    return "\n\n".join(sections)
|
|
|
def _check_tts_outcome(outcome):
    """Split a TTSClient.generate_audio() result into (audio_path, error).

    Exactly one of the two returned values is None.
    """
    if isinstance(outcome, tuple):
        if len(outcome) < 2:
            return None, "Invalid response from TTS service"
        if outcome[0] is None:
            return None, outcome[1]
        audio_path = outcome[0]
    else:
        audio_path = outcome

    if not isinstance(audio_path, str):
        return None, f"Invalid audio path type: {type(audio_path)}"
    if not os.path.exists(audio_path):
        return None, f"Audio file not found at {audio_path}"
    return audio_path, None


def generate_audio(story):
    """Generate narration audio for a story, retrying up to three times.

    The shared TTS client is created on first use. A client left in an error
    state by an earlier failure is rebuilt, so one failed connection does not
    disable narration for the rest of the session.

    Args:
        story: The story text to convert to audio.

    Returns:
        tuple: ``(audio_path, status_message)`` on success, or
        ``(None, error_message)`` on failure.
    """
    global tts_client

    # Refuse placeholders and error text produced by generate_story().
    if not story or story.startswith(("Error", "Please upload")):
        return None, "Please generate a valid story first."

    if tts_client is None or getattr(tts_client, 'error_message', None):
        print("Creating new TTS client")
        tts_client = TTSClient()

    # An error right after construction cannot be fixed by retrying.
    if getattr(tts_client, 'error_message', None):
        return None, f"TTS client error: {tts_client.error_message}"

    max_attempts = 3
    last_error = None

    for attempt in range(1, max_attempts + 1):
        print(f"TTS generation attempt {attempt} of {max_attempts}")

        try:
            outcome = tts_client.generate_audio(text=story)
        except Exception as e:
            last_error = str(e)
            print(f"Exception on TTS attempt {attempt}: {last_error}")
            continue

        audio_path, last_error = _check_tts_outcome(outcome)
        if audio_path is not None:
            return audio_path, "Audio generated successfully!"
        print(f"TTS error on attempt {attempt}: {last_error}")

    return None, f"TTS failed after {max_attempts} attempts. Last error: {last_error}"
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: upload an image, generate a story, then narrate it on request.
with gr.Blocks(title="Rawi Kids Story Generator") as demo:
    gr.Markdown("# Rawi Kids Story Generator")
    gr.Markdown("Upload an image and get a story for kids with audio narration!")

    # Holds the most recently generated story; this is what the audio button
    # reads, not the textbox.
    story_state = gr.State("")

    with gr.Column():
        gr.Markdown("## Generate Story")
        image_input = gr.Image(type="pil", label="Upload Image")
        generate_story_btn = gr.Button("Generate Story", variant="primary")
        story_output = gr.Textbox(label="Generated Story", lines=10)

        gr.Markdown("## Generate Audio Narration")
        generate_audio_btn = gr.Button("Listen to Story", variant="secondary")
        audio_status = gr.Textbox(label="Audio Status", lines=1)
        audio_output = gr.Audio(label="Story Narration", type="filepath")

    # Generate the story into the textbox, then copy the textbox value into
    # story_state once generation has finished.
    generate_story_btn.click(
        fn=generate_story,
        inputs=[image_input],
        outputs=story_output
    ).then(
        fn=lambda s: s,
        inputs=[story_output],
        outputs=story_state
    )

    # generate_audio returns (audio path, status text), matching the two outputs.
    generate_audio_btn.click(
        fn=generate_audio,
        inputs=[story_state],
        outputs=[audio_output, audio_status]
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()