# TeapotAI Discord Bot — Streamlit Space entry point (page UI is a status stub; the real work is the Discord client below).
import asyncio
import functools
import hashlib
import os
import random
import re
import time
from typing import List, Optional

import aiohttp
import discord
import numpy as np
import streamlit as st
from langsmith import traceable
from pydantic import BaseModel
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide") | |
tokenizer = None | |
model = None | |
model_name = "teapotai/teapotllm" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
def log_time(func):
    """Decorate an async function to print its wall-clock runtime.

    After each awaited call completes, prints
    ``"<name> executed in X.XXXX seconds"`` and returns the result.

    Args:
        func: a coroutine function to time.

    Returns:
        An async wrapper with the same call signature.
    """
    # functools.wraps preserves func.__name__/__doc__ on the wrapper;
    # without it, introspection (and debugging) sees "wrapper".
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        result = await func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
        return result
    return wrapper
API_KEY = os.environ.get("brave_api_key") | |
async def brave_search(query, count=1):
    """Query the Brave web-search API.

    Args:
        query: free-text search string.
        count: number of results to request (default 1).

    Returns:
        A list of ``(title, description, url)`` tuples, or ``[]`` when the
        API responds with a non-200 status.
    """
    endpoint = "https://api.search.brave.com/res/v1/web/search"
    request_headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    query_params = {"q": query, "count": count}
    async with aiohttp.ClientSession() as session:
        async with session.get(endpoint, headers=request_headers, params=query_params) as response:
            # Guard clause: any non-OK status is logged and treated as "no results".
            if response.status != 200:
                print(f"Error: {response.status}, {await response.text()}")
                return []
            results = await response.json()
            print(results)
            web_hits = results.get("web", {}).get("results", [])
            return [(hit["title"], hit["description"], hit["url"]) for hit in web_hits]
import re | |
import urllib.request | |
import html # For decoding HTML escape codes | |
# Function to extract the first URL from the text and remove others | |
def extract_first_url(query):
    """Split *query* into (text with all URLs removed, first URL found).

    When the text contains no http(s) URL, the original text is returned
    unchanged with ``None`` as the second element.
    """
    url_pattern = r'https?://\S+'
    found = re.findall(url_pattern, query)
    if not found:
        return query, None
    # Strip every URL from the text; only the first is reported back.
    stripped = re.sub(url_pattern, '', query)
    return stripped, found[0]
async def extract_text_from_html(url, max_words=250, max_chars=2000):
    """Fetch *url* and return the plain text of its ``<p>`` tags.

    The extracted text is truncated to at most *max_words* words and then
    hard-capped at *max_chars* characters.
    """
    # Download the page body asynchronously.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            page = await response.text()
    # Grab paragraph contents, strip any nested tags, and decode HTML
    # entities (e.g. &lt; -> <).
    paragraphs = re.findall(r'<p>(.*?)</p>', page, re.DOTALL)
    cleaned = [html.unescape(re.sub(r'<.*?>', '', p)) for p in paragraphs]
    joined = ' '.join(cleaned)
    # Word cap first, then character cap.
    truncated = ' '.join(joined.split()[:max_words])
    return truncated[:max_chars]
# pipeline_lock = asyncio.Lock() | |
async def query_teapot(prompt, context, user_input):
    """Run the TeapotLLM seq2seq model on prompt + context + user input.

    Generation is dispatched to a worker thread via asyncio.to_thread so
    the event loop stays responsive while the model runs.
    """
    combined = "\n".join((prompt, context, user_input))
    encoded = tokenizer(combined, return_tensors="pt")
    generated = await asyncio.to_thread(model.generate, **encoded, max_new_tokens=512)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
async def handle_chat(user_input): | |
results = [] | |
### Handle logic for scraping, search or translation | |
prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization.""" | |
# Check if there's a URL and process the input | |
processed_query, url = extract_first_url(user_input) | |
# If there's a URL, fetch the context | |
if url: | |
search_start_time = time.time() | |
context = await extract_text_from_html(url) | |
user_input = processed_query | |
search_end_time = time.time() | |
else: | |
# Custom prompt shims | |
if "translate" in user_input: | |
search_start_time = time.time() | |
context="" | |
prompt="" | |
search_end_time = time.time() | |
else: # Search task | |
search_start_time = time.time() | |
if len(user_input)<400 and "context:" not in user_input and "Context:" not in user_input: | |
results = await brave_search(user_input) | |
if len(results)==0: # No information | |
return "I'm sorry but I don't have any information on that.", "" | |
documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results] | |
context = "\n".join(documents) | |
else: | |
context="" # User provide context | |
search_end_time = time.time() | |
generation_start_time = time.time() | |
response = await query_teapot(prompt, context, user_input) | |
generation_end_time = time.time() | |
debug_info = f""" | |
Prompt: | |
{prompt} | |
Context: | |
{context} | |
Query: | |
{user_input} | |
Search time: {search_end_time - search_start_time:.2f} seconds | |
Generation time: {generation_end_time - generation_start_time:.2f} seconds | |
Response: {response} | |
""" | |
return response, debug_info | |
st.write("418 I'm a teapot") | |
DISCORD_TOKEN = os.environ.get("discord_key") | |
# Create an instance of Intents and enable the required ones | |
intents = discord.Intents.default() # Default intents enable basic functionality | |
intents.messages = True # Enable message-related events | |
# Create an instance of a client with the intents | |
client = discord.Client(intents=intents) | |
# Event when the bot has connected to the server.
# NOTE: the handler must be registered with @client.event — a bare
# module-level coroutine is never dispatched by discord.py.
@client.event
async def on_ready():
    """Log the bot's identity once the gateway connection is ready."""
    print(f'Logged in as {client.user}')
# Event when a message is received.
# NOTE: registered via @client.event — without it discord.py never calls
# this handler and the bot silently ignores every message.
@client.event
async def on_message(message):
    """Reply to messages that mention the bot.

    Strips the bot mention and any "debug:"/"Debug:" marker, generates a
    reply via handle_chat, and — when debug mode was requested — attaches
    the debug trace in a thread under the reply.
    """
    # Check if the message is from the bot itself to prevent a loop
    if message.author == client.user:
        return
    # Exit the function if the bot is not mentioned
    if f'<@{client.user.id}>' not in message.content:
        return
    print(message.content)
    is_debug = "debug:" in message.content or "Debug:" in message.content
    # Show a typing indicator while generating the response.
    async with message.channel.typing():
        cleaned_message = message.content.replace("debug:", "").replace("Debug:", "").replace(f'<@{client.user.id}>', "")
        response, debug_info = await handle_chat(cleaned_message)
        print(response)
        sent_message = await message.reply(response)
        # Create a thread from the sent message
        if is_debug:
            thread = await sent_message.create_thread(name=f"""Debug Thread: '{cleaned_message[:80]}'""", auto_archive_duration=60)
            # Send a message in the created thread
            await thread.send(debug_info)
def initialize():
    """Mark the Streamlit session as initialized and start the Discord bot.

    ``client.run`` blocks for the lifetime of the process, so this call
    never returns under normal operation.
    """
    st.session_state["initialized"] = True
    client.run(DISCORD_TOKEN)

initialize()