Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import random | |
| import openai | |
| from openai import APIError, APIConnectionError, RateLimitError | |
| import os | |
| from PIL import Image # This is the corrected import | |
| import io | |
| import base64 | |
| import asyncio | |
| from queue import Queue | |
| from threading import Thread | |
| import time | |
# Resolve the avatar directory relative to this script, not the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
avatars_dir = os.path.join(current_dir, "avatars")

# Predefined character name -> avatar image filename inside ``avatars_dir``.
character_avatars = {
    "Harry Potter": "harry.png",
    "Hermione Granger": "hermione.png",
    "poor Ph.D. student": "phd.png",
    "Donald Trump": "trump.png",
    "a super cute red panda": "red_panda.png",
}

# Fallback API keys read from the environment; each is None when the variable is unset.
BACKUP_API_KEY_0 = os.environ.get("BACKUP_API_KEY_0")
BACKUP_API_KEY_1 = os.environ.get("BACKUP_API_KEY_1")
BACKUP_API_KEYS = [BACKUP_API_KEY_0, BACKUP_API_KEY_1]

# Dropdown choices: the predefined characters, in the same order as ``character_avatars``.
predefined_characters = list(character_avatars)
def get_character(dropdown_value, custom_value):
    """Return the character name chosen in the UI.

    The free-text ``custom_value`` is used only when the dropdown is set to the
    literal "Custom" option; any other dropdown selection is returned as-is.
    """
    if dropdown_value != "Custom":
        return dropdown_value
    return custom_value
def resize_image(image_path, size=(100, 100)):
    """Shrink an image to fit within ``size`` and return it as base64-encoded PNG text.

    Uses PIL's ``thumbnail`` on the opened image, re-encodes it as PNG in memory,
    and returns the base64 string. Returns None when ``image_path`` does not exist.
    """
    if os.path.exists(image_path):
        with Image.open(image_path) as picture:
            picture.thumbnail(size)
            png_buffer = io.BytesIO()
            picture.save(png_buffer, format="PNG")
            return base64.b64encode(png_buffer.getvalue()).decode()
    return None
# Base64 PNG thumbnails keyed by character name. Avatars are best-effort:
# a missing or unreadable image only means that character is shown without
# an avatar; it must not stop the whole app from starting.
resized_avatars = {}
for character, filename in character_avatars.items():
    full_path = os.path.join(avatars_dir, filename)
    if not os.path.exists(full_path):
        continue
    try:
        encoded = resize_image(full_path)
    except OSError as err:
        # PIL reports corrupt/unsupported images with OSError subclasses
        # (e.g. UnidentifiedImageError); skip the avatar instead of crashing.
        print(f"Warning: could not load avatar {full_path!r}: {err}")
        continue
    if encoded:
        resized_avatars[character] = encoded
async def generate_response_stream(messages, user_api_key):
    """Stream a chat completion from SambaNova, falling back to backup keys on rate limits.

    The user's key is tried first, then each backup key that is actually
    configured (unset environment variables are skipped). The next key is
    tried only when the current one raises ``RateLimitError``.

    Args:
        messages: Chat history in OpenAI ``messages`` format.
        user_api_key: The API key entered by the user.

    Yields:
        The cumulative response text after each non-empty streamed chunk.

    Raises:
        Exception: "Rate limit exceeded" when every key is rate limited.
        Any other error from the client propagates unchanged.
    """
    # Unset backup keys are None; passing them would make the client look up
    # a different environment variable or fail with an unrelated error.
    api_keys = [user_api_key] + [key for key in BACKUP_API_KEYS if key]
    for idx, api_key in enumerate(api_keys):
        is_last_key = idx == len(api_keys) - 1
        try:
            # ``async with`` closes the client's HTTP connections once done,
            # instead of leaking one client per attempt.
            async with openai.AsyncOpenAI(
                api_key=api_key,
                base_url="https://api.sambanova.ai/v1",
            ) as client:
                response = await client.chat.completions.create(
                    model='Meta-Llama-3.1-405B-Instruct',
                    messages=messages,
                    temperature=0.7,
                    top_p=0.9,
                    stream=True
                )
                full_response = ""
                async for chunk in response:
                    # Some stream chunks can carry no choices; guard before indexing.
                    if chunk.choices and chunk.choices[0].delta.content:
                        full_response += chunk.choices[0].delta.content
                        yield full_response
            # Success: do not try any further keys.
            return
        except RateLimitError as err:
            if is_last_key:
                raise Exception("Rate limit exceeded") from err
            # Otherwise fall through to the next key.
async def simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
    """Stream a role-played dialogue between two characters as HTML snapshots.

    ``character1`` opens with ``initial_message``; the characters then
    alternate, starting with ``character2``, for ``num_turns`` exchanges in
    total (``2 * num_turns`` messages including the opening line). Each
    character keeps its own chat history in which its own lines are
    ``assistant`` messages and the other character's lines are ``user``
    messages.

    Yields:
        The whole conversation rendered by ``format_conversation_as_html``
        after every streamed update. If a request fails, the in-progress
        message is replaced by a "System" error message and the
        conversation stops.
    """
    messages_character_1 = [
        {"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character1}."},
        {"role": "assistant", "content": initial_message}
    ]
    messages_character_2 = [
        {"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character2}."},
        {"role": "user", "content": initial_message}
    ]
    conversation = [
        {"character": character1, "content": initial_message},
    ]
    yield format_conversation_as_html(conversation)

    # Gradio sliders are documented to deliver floats, and range() rejects
    # floats, so coerce first. The opening line counts as the first message.
    total_replies = 2 * int(num_turns) - 1
    for turn_num in range(total_replies):
        if turn_num % 2 == 0:
            current_character = character2
            speaker_history, listener_history = messages_character_2, messages_character_1
        else:
            current_character = character1
            speaker_history, listener_history = messages_character_1, messages_character_2

        # Placeholder bubble that is filled in as the response streams.
        conversation.append({"character": current_character, "content": ""})
        full_response = ""
        try:
            async for response in generate_response_stream(speaker_history, api_key):
                full_response = response
                conversation[-1]["content"] = full_response
                yield format_conversation_as_html(conversation)
        except Exception as e:
            # Show the failure in place of the message and stop the conversation.
            conversation[-1]["character"] = "System"
            conversation[-1]["content"] = f"Error: {str(e)}"
            yield format_conversation_as_html(conversation)
            break

        # Record the finished line from both characters' points of view.
        speaker_history.append({"role": "assistant", "content": full_response})
        listener_history.append({"role": "user", "content": full_response})
def stream_conversation(character1, character2, initial_message, num_turns, api_key, queue):
    """Run the async conversation on a fresh event loop in the calling thread.

    Every HTML snapshot from ``simulate_conversation_stream`` is put on
    ``queue``. If the stream raises an ``Exception``, a plain-text
    ``"Error: ..."`` string is put instead. A final ``None`` sentinel is
    always put, even if the event loop itself fails. This keeps a consumer
    blocked on ``queue.get()`` from hanging forever.
    """
    async def run_simulation():
        try:
            async for html in simulate_conversation_stream(
                character1, character2, initial_message, num_turns, api_key
            ):
                queue.put(html)
        except Exception as e:
            queue.put(f"Error: {str(e)}")

    try:
        # asyncio.run creates a new loop for this worker thread and, unlike a
        # bare run_until_complete + close, finalizes async generators first.
        asyncio.run(run_simulation())
    finally:
        queue.put(None)  # Signal that the conversation is complete
def validate_api_key(api_key):
    """Check that an API key was entered.

    Returns:
        ``(True, "")`` for any non-blank key, otherwise ``(False, message)``.
        ``None`` is treated like an empty key instead of raising
        ``AttributeError``.
    """
    if not api_key or not api_key.strip():
        return False, "API key is required. Please enter a valid API key."
    return True, ""
def update_api_key_status(api_key):
    """Return a red HTML warning for an invalid API key, or an empty string when it is valid."""
    is_valid, message = validate_api_key(api_key)
    return "" if is_valid else f"<p style='color: red;'>{message}</p>"
def chat_interface(character1_dropdown, character1_custom, character2_dropdown, character2_custom,
                   initial_message, num_turns, api_key):
    """Gradio handler that relays conversation HTML produced by a worker thread.

    The async conversation runs in a separate thread via ``stream_conversation``.
    That thread pushes HTML snapshots onto a queue, and this generator yields
    each one until the ``None`` end-of-stream sentinel arrives. It then joins
    the worker thread.
    """
    speaker_one = get_character(character1_dropdown, character1_custom)
    speaker_two = get_character(character2_dropdown, character2_custom)
    updates = Queue()
    worker = Thread(
        target=stream_conversation,
        args=(speaker_one, speaker_two, initial_message, num_turns, api_key, updates),
    )
    worker.start()
    # iter(callable, sentinel) keeps calling updates.get() until it returns None.
    yield from iter(updates.get, None)
    worker.join()
def format_conversation_as_html(conversation):
    """Render the conversation as chat-bubble HTML.

    Args:
        conversation: List of ``{"character": str, "content": str}`` dicts.

    Messages alternate left/right by position, and a character's avatar from
    ``resized_avatars`` is shown when one was loaded. Character names and
    message text come from user input and model output. They are
    HTML-escaped so stray markup cannot break the layout, inject script, or
    confuse the text extraction in ``format_chat_for_download``.
    """
    import html

    parts = ["""
<style>
    .chat-container {
        display: flex;
        flex-direction: column;
        gap: 10px;
        font-family: Arial, sans-serif;
    }
    .message {
        display: flex;
        padding: 10px;
        border-radius: 10px;
        max-width: 80%;
        align-items: flex-start;
    }
    .left {
        align-self: flex-start;
        background-color: #1565C0;
        color: #FFFFFF;
    }
    .right {
        align-self: flex-end;
        background-color: #2E7D32;
        color: #FFFFFF;
        flex-direction: row-reverse;
    }
    .avatar-container {
        flex-shrink: 0;
        width: 40px;
        height: 40px;
        margin: 0 10px;
    }
    .avatar {
        width: 100%;
        height: 100%;
        border-radius: 50%;
        object-fit: cover;
    }
    .message-content {
        display: flex;
        flex-direction: column;
        min-width: 150px;
        flex-grow: 1;
    }
    .character-name {
        font-weight: bold;
        margin-bottom: 5px;
    }
    .message-text {
        word-wrap: break-word;
        overflow-wrap: break-word;
    }
</style>
<div class="chat-container">
"""]
    for i, message in enumerate(conversation):
        align = "left" if i % 2 == 0 else "right"
        # Avatar lookup uses the raw name; only the rendered text is escaped.
        avatar_data = resized_avatars.get(message["character"])
        character = html.escape(message["character"])
        content = html.escape(message["content"])
        parts.append(f'<div class="message {align}">')
        if avatar_data:
            parts.append(f'''
<div class="avatar-container">
<img src="data:image/png;base64,{avatar_data}" class="avatar" alt="{character} avatar">
</div>
''')
        parts.append(f'''
<div class="message-content">
<div class="character-name">{character}</div>
<div class="message-text">{content}</div>
</div>
</div>
''')
    parts.append("</div>")
    return "".join(parts)


def format_chat_for_download(html_chat):
    """Convert chat HTML produced by ``format_conversation_as_html`` into plain text.

    Produces one ``"Speaker: message"`` line per chat bubble. HTML entities
    are unescaped, so the text file contains the original characters.
    Returns "" for an empty or missing chat, e.g. when Download is clicked
    before any conversation was generated.
    """
    import html
    import re

    if not html_chat:
        return ""
    chat_text = re.findall(
        r'<div class="character-name">(.*?)</div>.*?<div class="message-text">(.*?)</div>',
        html_chat,
        re.DOTALL,
    )
    return "\n".join(
        f"{html.unescape(speaker.strip())}: {html.unescape(message.strip())}"
        for speaker, message in chat_text
    )
def save_chat_to_file(chat_content):
    """Write ``chat_content`` to a new UTF-8 text file and return its path.

    The file goes in ``downloads/`` under the current working directory and
    is named ``chat_<YYYYmmdd_HHMMSS>_<random>.txt``. ``tempfile`` creates the
    random suffix atomically, so two downloads in the same second (e.g. from
    different users of a shared deployment) cannot overwrite each other.
    """
    import datetime
    import tempfile

    # Create the downloads directory if it doesn't exist.
    downloads_dir = os.path.join(os.getcwd(), "downloads")
    os.makedirs(downloads_dir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(
        "w",
        encoding="utf-8",
        prefix=f"chat_{timestamp}_",
        suffix=".txt",
        dir=downloads_dir,
        delete=False,
    ) as f:
        f.write(chat_content)
        file_path = f.name
    return file_path
# Gradio UI: character pickers, conversation settings, streamed output and download.
with gr.Blocks() as app:
    gr.Markdown("# Character Chat Generator")
    gr.Markdown("Powered by [LLama3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on [SambaNova Cloud](https://cloud.sambanova.ai/apis)")
    api_key = gr.Textbox(label="Enter your Sambanova Cloud API Key\n(To get one, go to https://cloud.sambanova.ai/apis)", type="password")
    api_key_status = gr.Markdown()
    with gr.Column():
        character1_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 1")
        character1_custom = gr.Textbox(label="Custom Character 1 (if selected above)", visible=False)
    with gr.Column():
        character2_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 2")
        character2_custom = gr.Textbox(label="Custom Character 2 (if selected above)", visible=False)
    initial_message = gr.Textbox(label="Initial message (for Character 1)")
    num_turns = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of conversation turns")
    generate_btn = gr.Button("Generate Conversation")
    output = gr.HTML(label="Generated Conversation")

    def show_custom_input(choice):
        """Show a custom-name textbox only while "Custom" is selected in its dropdown."""
        return gr.update(visible=choice == "Custom")

    character1_dropdown.change(show_custom_input, inputs=character1_dropdown, outputs=character1_custom)
    character2_dropdown.change(show_custom_input, inputs=character2_dropdown, outputs=character2_custom)
    api_key.change(update_api_key_status, inputs=[api_key], outputs=[api_key_status])
    # chat_interface is a generator, so the HTML output updates as text streams in.
    generate_btn.click(
        chat_interface,
        inputs=[character1_dropdown, character1_custom, character2_dropdown,
                character2_custom, initial_message, num_turns, api_key],
        outputs=output,
    )

    gr.Markdown("## Download Chat History")
    download_btn = gr.Button("Download Conversation")
    download_output = gr.File(label="Download")

    def download_conversation(html_chat):
        """Convert the rendered chat HTML to text, save it, and return the file path."""
        chat_content = format_chat_for_download(html_chat)
        file_path = save_chat_to_file(chat_content)
        return file_path

    download_btn.click(
        download_conversation,
        inputs=output,
        outputs=download_output
    )
    # Show the "API key is required" hint as soon as the page loads.
    app.load(lambda: update_api_key_status(""), outputs=[api_key_status])

if __name__ == "__main__":
    app.launch()