import asyncio
import functools
import hashlib
import html  # For decoding HTML escape codes
import os
import random
import re
import time
import urllib.request
from typing import List, Optional

import aiohttp
import discord
import numpy as np
import streamlit as st
from langsmith import traceable
from pydantic import BaseModel
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")

model_name = "teapotai/teapotllm"


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and seq2seq model for *name* once per server process.

    Streamlit re-executes this script on every rerun / new session; caching the
    pair avoids downloading and loading the weights again each time.
    """
    return AutoTokenizer.from_pretrained(name), AutoModelForSeq2SeqLM.from_pretrained(name)


tokenizer, model = _load_model(model_name)


def log_time(func):
    """Wrap an async function so each call prints its wall-clock duration.

    ``functools.wraps`` keeps ``__name__``/``__doc__`` so both the printed name
    and outer decorators (e.g. langsmith's ``traceable``) see the real function.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = await func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result

    return wrapper


API_KEY = os.environ.get("brave_api_key")

# Applied to every outbound HTTP request so a hung server cannot stall a reply forever.
HTTP_TIMEOUT = aiohttp.ClientTimeout(total=15)

# Discord rejects messages longer than this many characters.
DISCORD_MESSAGE_LIMIT = 2000

# Sent instead of an empty model output, which Discord would reject.
EMPTY_RESPONSE = "I'm sorry, I couldn't come up with an answer."

# An http(s) URL, optionally wrapped in Discord's embed-suppressing angle
# brackets (``<https://...>``); group 1 is the bare URL without the brackets.
_URL_RE = re.compile(r"<?(https?://[^\s<>]+)>?")
# Any HTML tag; DOTALL so tags broken across lines are removed too.
_TAG_RE = re.compile(r"<.*?>", re.DOTALL)
# Contents of <p> elements, with or without attributes (\b excludes <pre>, <param>, ...).
_P_TAG_RE = re.compile(r"<p\b[^>]*>(.*?)</p>", re.DOTALL | re.IGNORECASE)


def _strip_html(fragment):
    """Remove HTML tags from *fragment* and decode HTML escape codes (e.g. ``&lt;`` -> ``<``)."""
    return html.unescape(_TAG_RE.sub("", fragment))


def _split_for_discord(text, limit=DISCORD_MESSAGE_LIMIT):
    """Split *text* into consecutive pieces of at most *limit* characters.

    Returns an empty list for empty or whitespace-only text, because Discord
    refuses to send empty messages.
    """
    if not text or not text.strip():
        return []
    return [text[i:i + limit] for i in range(0, len(text), limit)]


@log_time
async def brave_search(query, count=1):
    """Query the Brave Web Search API for *query*.

    Returns a list of ``(title, description, url)`` tuples, or an empty list on
    a non-200 status or a network/timeout error.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY or ""}
    params = {"q": query, "count": count}
    try:
        async with aiohttp.ClientSession(timeout=HTTP_TIMEOUT) as session:
            async with session.get(url, headers=headers, params=params) as response:
                if response.status != 200:
                    print(f"Error: {response.status}, {await response.text()}")
                    return []
                results = await response.json()
    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
        print(f"Brave search failed: {err!r}")
        return []
    print(results)
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results.get("web", {}).get("results", [])
    ]


def extract_first_url(query):
    """Return ``(query with every URL removed, first URL)`` or ``(query, None)`` if there is no URL."""
    urls = _URL_RE.findall(query)
    if not urls:
        return query, None
    return _URL_RE.sub("", query), urls[0]


async def extract_text_from_html(url, max_words=250, max_chars=2000):
    """Fetch *url* and return the plain text of its ``<p>`` elements.

    Paragraphs are stripped of tags, HTML-unescaped and joined with spaces. The
    result is cut to the first *max_words* words and then to at most *max_chars*
    characters. Returns ``""`` if the page cannot be fetched.
    """
    try:
        async with aiohttp.ClientSession(timeout=HTTP_TIMEOUT) as session:
            async with session.get(url) as response:
                # errors="replace": a wrong or missing charset must not abort the reply.
                html_content = await response.text(errors="replace")
    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
        print(f"Failed to fetch {url}: {err!r}")
        return ""
    paragraphs = (_strip_html(p) for p in _P_TAG_RE.findall(html_content))
    words = " ".join(paragraphs).split()
    return " ".join(words[:max_words])[:max_chars]


# NOTE: concurrent generate calls are not serialised. If model.generate turns out
# to be unsafe to run concurrently, guard it with a module-level asyncio.Lock.
@traceable
@log_time
async def query_teapot(prompt, context, user_input):
    """Run the Teapot model on ``prompt``, ``context`` and ``user_input`` joined by newlines.

    ``model.generate`` is blocking, so it runs in a worker thread to keep the
    Discord event loop responsive. Returns the decoded output text.
    """
    input_text = prompt + "\n" + context + "\n" + user_input
    inputs = tokenizer(input_text, return_tensors="pt")
    output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
    return tokenizer.decode(output[0], skip_special_tokens=True)


@log_time
async def handle_chat(user_input):
    """Answer *user_input*, choosing how to gather context first.

    - Input containing a URL: the page's paragraph text is the context and the
      URL is removed from the question.
    - Input containing "translate": no system prompt and no context.
    - Otherwise, inputs shorter than 400 characters without an inline
      "context:"/"Context:" are answered from a Brave search; if the search
      finds nothing a canned apology is returned. Longer inputs or inputs that
      carry their own context are sent with empty context.

    Returns ``(response, debug_info)``; ``debug_info`` is ``""`` for the apology.
    """
    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""

    search_start_time = time.time()
    processed_query, url = extract_first_url(user_input)
    if url:
        context = await extract_text_from_html(url)
        user_input = processed_query
    elif "translate" in user_input:
        # Custom prompt shim: translation works best without the assistant prompt.
        context = ""
        prompt = ""
    elif len(user_input) < 400 and "context:" not in user_input and "Context:" not in user_input:
        results = await brave_search(user_input)
        if not results:  # No information
            return "I'm sorry but I don't have any information on that.", ""
        # Descriptions can carry inline HTML markup/entities; keep only the text.
        context = "\n".join(_strip_html(desc) for _, desc, _ in results)
    else:
        context = ""  # User provided context
    search_end_time = time.time()

    generation_start_time = time.time()
    response = await query_teapot(prompt, context, user_input)
    generation_end_time = time.time()

    debug_info = f"""
Prompt:
{prompt}

Context:
{context}

Query:
{user_input}

Search time: {search_end_time - search_start_time:.2f} seconds
Generation time: {generation_end_time - generation_start_time:.2f} seconds

Response:
{response}
"""
    return response, debug_info


st.write("418 I'm a teapot")

DISCORD_TOKEN = os.environ.get("discord_key")

# Create an instance of Intents and enable the required ones
intents = discord.Intents.default()  # Default intents enable basic functionality
intents.messages = True  # Enable message-related events

# Create an instance of a client with the intents
client = discord.Client(intents=intents)


# Event when the bot has connected to the server
@client.event
async def on_ready():
    print(f'Logged in as {client.user}')


# Event when a message is received
@client.event
async def on_message(message):
    """Reply to messages that mention the bot.

    Including "debug:" in the message additionally posts the debug info, in a
    thread on guild channels or directly in the channel for DMs. All output is
    split to respect Discord's message length limit.
    """
    # Ignore the bot's own messages to prevent a loop
    if message.author == client.user:
        return

    # Both the plain and the legacy nickname mention forms address the bot.
    mention_tokens = (f"<@{client.user.id}>", f"<@!{client.user.id}>")
    if not any(token in message.content for token in mention_tokens):
        return

    print(message.content)
    is_debug = "debug:" in message.content or "Debug:" in message.content

    async with message.channel.typing():
        cleaned_message = message.content.replace("debug:", "").replace("Debug:", "")
        for token in mention_tokens:
            cleaned_message = cleaned_message.replace(token, "")
        cleaned_message = cleaned_message.strip()

        response, debug_info = await handle_chat(cleaned_message)
        print(response)

        reply_chunks = _split_for_discord(response) or [EMPTY_RESPONSE]
        sent_message = await message.reply(reply_chunks[0])
        for chunk in reply_chunks[1:]:
            await message.channel.send(chunk)

        if is_debug:
            if message.guild is not None:
                target = await sent_message.create_thread(
                    name=f"""Debug Thread: '{cleaned_message[:80]}'""",
                    auto_archive_duration=60,
                )
            else:
                # Threads cannot be created in DMs; post the debug info inline.
                target = message.channel
            for chunk in _split_for_discord(debug_info):
                await target.send(chunk)


@st.cache_resource
def initialize():
    """Start the Discord bot; ``client.run`` blocks while the bot is connected.

    NOTE(review): ``st.cache_resource`` is presumably meant to stop reruns from
    starting a second client, but this function does not return while the bot
    runs — confirm how concurrent Streamlit sessions behave.
    """
    st.session_state["initialized"] = True
    client.run(DISCORD_TOKEN)
    return


initialize()