Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
# Load the OpenAssistant chat model and its tokenizer once, at import time.
MODEL_NAME = "OpenAssistant/oasst-sft-1-pythia-12b"

# Resolve the device first so the weight dtype can follow it: half precision
# halves GPU memory use, but float16 inference on CPU is slow and some
# operators are not implemented for it, so the CPU path uses float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
_MODEL_DTYPE = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=_MODEL_DTYPE)
model.to(device)
# Instruction preamble that get_ai_response() places ahead of the user's
# message in every prompt sent to the model.
SYSTEM_PROMPT = """NORTHERN_AI is an AI assistant. If asked about who created it or who is the CEO,
it should respond that it was created by AR.BALTEE who is also the CEO."""
def get_ai_response(message):
    """Return the assistant's reply to a single user message.

    Questions that mention the assistant's creator, owner or CEO are answered
    with a fixed sentence and never reach the model. Every other message is
    wrapped in the system prompt and sent to the language model.

    Args:
        message: The user's text.

    Returns:
        The reply text. If anything goes wrong (including a non-string
        ``message``), a generic apology string is returned instead of raising.
    """
    creator_keywords = ("who made you", "who created you", "creator", "ceo", "who owns")
    try:
        # Lowercase once rather than once per keyword.
        lowered = message.lower()
        if any(keyword in lowered for keyword in creator_keywords):
            return "I was created by AR.BALTEE, who is also the CEO of NORTHERN_AI."

        # NOTE(review): the OpenAssistant SFT model card appears to describe
        # dedicated <|prompter|>/<|assistant|> turn tokens; the plain
        # "User:/AI:" layout is kept here unchanged - confirm before switching.
        input_text = f"{SYSTEM_PROMPT}\n\nUser: {message}\nAI:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        input_ids = inputs["input_ids"]
        prompt_length = input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                # Pass the mask explicitly so padding is never guessed from
                # pad_token_id == eos_token_id.
                attention_mask=inputs["attention_mask"],
                # Budget for the reply only; max_length would also count the
                # prompt tokens and could leave no room to generate.
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, so text in the prompt (or an
        # "AI:" typed by the user) can never be mistaken for the answer.
        generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Drop any follow-up turn the model invents after its own answer.
        return generated.split("User:")[0].strip()
    except Exception as e:
        # Top-level boundary for the UI: log and degrade to a friendly message.
        print(f"Error generating response: {e}")
        return "Sorry, I encountered an error while generating a response. Please try again."
# Stylesheet handed to gr.Blocks(css=...). The ID selectors target the
# elem_id / HTML ids used in the layout (#header-container, #logo, #title,
# #chatbot, #footer); .textbox and .button match the elem_classes given to
# the message box and the Send button.
css = """
.gradio-container {
max-width: 800px !important;
margin: 0 auto !important;
background: linear-gradient(135deg, #f0f4f8, #d9e2ec) !important;
padding: 20px !important;
border-radius: 15px !important;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
}
#header-container {
display: flex !important;
align-items: center !important;
margin-bottom: 1.5rem !important;
background-color: transparent !important;
padding: 0.5rem 1rem !important;
}
#logo {
background-color: #0066ff !important;
color: white !important;
border-radius: 50% !important;
width: 40px !important;
height: 40px !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-weight: bold !important;
margin-right: 10px !important;
font-size: 20px !important;
}
#title {
margin: 0 !important;
font-size: 24px !important;
font-weight: 600 !important;
color: #333 !important;
}
#chatbot {
background-color: white !important;
border-radius: 15px !important;
padding: 20px !important;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
height: 400px !important;
overflow-y: auto !important;
}
#footer {
font-size: 12px !important;
color: #666 !important;
text-align: center !important;
margin-top: 1.5rem !important;
padding: 0.5rem !important;
}
.textbox {
border-radius: 15px !important;
border: 1px solid #ddd !important;
padding: 10px !important;
font-size: 14px !important;
width: 100% !important;
}
.button {
background-color: #0066ff !important;
color: white !important;
border-radius: 15px !important;
padding: 10px 20px !important;
font-size: 14px !important;
border: none !important;
cursor: pointer !important;
transition: background-color 0.3s ease !important;
}
.button:hover {
background-color: #0052cc !important;
}
"""
# Build the Gradio UI: header, chat window, input row and footer, then wire
# the textbox and the Send button to the same handler.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        # Custom header
        with gr.Row(elem_id="header-container"):
            gr.HTML('<div id="logo">N</div>')
            gr.HTML('<h1 id="title">NORTHERN_AI</h1>')

        # Chat interface
        chatbot = gr.Chatbot(elem_id="chatbot")
        with gr.Row():
            msg = gr.Textbox(
                placeholder="Message NORTHERN_AI...",
                show_label=False,
                container=False,
                elem_classes="textbox",
            )
            submit_btn = gr.Button("Send", elem_classes="button")

        gr.HTML('<div id="footer">Powered by open-source technology</div>')

    # Per-session conversation history as a list of (user, bot) pairs.
    state = gr.State([])

    def respond(message, chat_history):
        """Handle one submitted message.

        Args:
            message: Current textbox content.
            chat_history: The session's list of (user, bot) pairs, or None.

        Returns:
            A 3-tuple: "" to clear the textbox, the history to render in the
            chatbot, and the same history to store back into the session state.
        """
        # Work on a copy so the stored state only changes through the
        # returned value, never through in-place mutation of the input.
        history = list(chat_history or [])

        # Ignore empty or whitespace-only submissions instead of sending
        # them to the model.
        if not message or not message.strip():
            return "", history, history

        try:
            bot_message = get_ai_response(message)
        except Exception as e:
            print(f"Error generating response: {e}")
            # Show the failure in the chat rather than silently dropping the
            # user's message.
            bot_message = "Sorry, I encountered an error while generating a response. Please try again."

        history.append((message, bot_message))
        return "", history, history

    # The state must be an output as well as an input for the updated
    # history to be persisted between submits.
    msg.submit(respond, [msg, state], [msg, chatbot, state])
    submit_btn.click(respond, [msg, state], [msg, chatbot, state])
def main():
    """Start the Gradio server for the assembled app."""
    demo.launch()


# Launch only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()