zakerytclarke's picture
Update app.py
a83ab40 verified
raw
history blame
5.07 kB
import streamlit as st
import hashlib
import os
import aiohttp
import asyncio
import time
from langsmith import traceable
import random
import discord
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pydantic import BaseModel
from typing import List, Optional
from tqdm import tqdm
import re
import os
# Streamlit page metadata for this app.
st.set_page_config(
    page_title="TeapotAI Discord Bot",
    page_icon=":robot_face:",
    layout="wide",
)

# Load the tokenizer and seq2seq model once at import time; the module-level
# names `tokenizer` and `model` are read by query_teapot().
model_name = "teapotai/teapotllm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def log_time(func):
    """Decorate an async callable so each call prints how long it took.

    Args:
        func: Coroutine function to wrap. It is awaited with the caller's
            positional and keyword arguments unchanged.

    Returns:
        A coroutine function that awaits ``func``, prints
        ``"<name> executed in <seconds> seconds"`` and returns ``func``'s
        result. Nothing is printed when ``func`` raises.
    """
    # Imported locally so the decorator is self-contained and does not rely
    # on the file-level import block.
    from functools import wraps

    # wraps() copies __name__/__doc__ from the decorated coroutine, so code
    # that inspects the function (including decorators stacked on top of this
    # one) sees the real name instead of "wrapper".
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()
        result = await func(*args, **kwargs)  # func must be awaitable
        elapsed = time.perf_counter() - started
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result

    return wrapper
# Brave Search subscription token, read from the environment (None if unset).
API_KEY = os.getenv("brave_api_key")
@log_time
async def brave_search(query, count=1):
    """Query the Brave web-search endpoint and return the listed results.

    Args:
        query: Search string sent as the ``q`` request parameter.
        count: Number of results requested from the API (default 1).

    Returns:
        A list of ``(title, description, url)`` tuples, one per entry under
        ``web.results`` in the JSON response. A field missing from an entry
        is returned as an empty string. An empty list is returned when the
        API answers with a non-200 status or the request itself fails.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, params=params) as response:
                if response.status != 200:
                    print(f"Error: {response.status}, {await response.text()}")
                    return []
                results = await response.json()
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        # Treat a transport failure like a non-200 answer: report it and let
        # the caller continue with no search results.
        print(f"Error: {exc!r}")
        return []

    print(results)
    # .get() with defaults so one entry lacking a field does not abort the
    # whole search with a KeyError.
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results.get("web", {}).get("results", [])
    ]
@traceable
@log_time
async def query_teapot(prompt, context, user_input):
    """Generate a model reply for one prompt / context / user message.

    The three strings are joined with newlines, tokenized with the
    module-level ``tokenizer`` and passed to the module-level ``model``.

    Args:
        prompt: Instruction text placed first in the model input.
        context: Supporting text placed between the prompt and the message.
        user_input: The user's message, placed last.

    Returns:
        The first generated sequence decoded to text with special tokens
        skipped.
    """
    input_text = prompt + "\n" + context + "\n" + user_input
    print(input_text)

    inputs = tokenizer(input_text, return_tensors="pt")
    # generate() is a blocking call; asyncio.to_thread runs it in a worker
    # thread so the event loop is not held up while the model runs.
    output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
    return tokenizer.decode(output[0], skip_special_tokens=True)
@log_time
async def handle_chat(user_input):
    """Answer a user message using web-search snippets as model context.

    Args:
        user_input: The text to search for and to pass to the model.

    Returns:
        A ``(response, debug_info)`` tuple: the model's reply, and a
        multi-line string listing the prompt, the context, the search and
        generation durations, and the reply.
    """
    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""

    # Retrieval step, timed for the debug report.
    search_began = time.time()
    hits = await brave_search(user_input)
    search_seconds = time.time() - search_began

    # Keep only the description of each hit, without the highlight tags.
    snippets = []
    for _title, description, _url in hits:
        snippets.append(description.replace('<strong>', '').replace('</strong>', ''))
    context = "\n".join(snippets)

    # Generation step, timed for the debug report.
    generation_began = time.time()
    response = await query_teapot(prompt, context, user_input)
    generation_seconds = time.time() - generation_began

    # The empty first and last entries give the report its leading and
    # trailing newline.
    report_lines = [
        "",
        "Prompt:",
        prompt,
        "Context:",
        context,
        f"Search time: {search_seconds:.2f} seconds",
        f"Generation time: {generation_seconds:.2f} seconds",
        f"Response: {response}",
        "",
    ]
    debug_info = "\n".join(report_lines)
    return response, debug_info
# Minimal page body shown by Streamlit.
st.write("418 I'm a teapot")

# Bot token is read from the environment (None if the variable is unset).
DISCORD_TOKEN = os.getenv("discord_key")

# Start from the library's default intents and switch on message events.
# NOTE(review): no other intent is enabled here — confirm the default set is
# enough for the mention-based handling in on_message.
intents = discord.Intents.default()
intents.messages = True

# Client instance that the event handlers in this file are registered on.
client = discord.Client(intents=intents)
# Fired once the bot has connected
@client.event
async def on_ready():
    """Print the account the client is logged in as."""
    bot_user = client.user
    print(f'Logged in as {bot_user}')
# Discord rejects message bodies longer than this many characters.
_DISCORD_MESSAGE_LIMIT = 2000
# Discord rejects thread names longer than this many characters.
_DISCORD_THREAD_NAME_LIMIT = 100


def _split_for_discord(text, limit=_DISCORD_MESSAGE_LIMIT):
    """Split *text* into consecutive pieces of at most *limit* characters."""
    return [text[start:start + limit] for start in range(0, len(text), limit)]


# Event when a message is received
@client.event
async def on_message(message):
    """Reply to a message that mentions the bot, optionally with debug output.

    The mention and the ``debug:`` marker are removed from the message text,
    the remainder is passed to ``handle_chat`` and the answer is sent as a
    reply. When the original text contained ``debug:``, a thread is created
    on the reply and the debug report is posted there.

    Args:
        message: The received Discord message.
    """
    # Ignore the bot's own messages to prevent a reply loop
    if message.author == client.user:
        return

    # Only respond when the bot is mentioned
    mention = f'<@{client.user.id}>'
    if mention not in message.content:
        return

    print(message.content)
    is_debug = "debug:" in message.content
    # strip() removes the whitespace left behind once the mention and marker
    # are cut out, so the search query and thread name start with real text.
    cleaned_message = message.content.replace("debug:", "").replace(mention, "").strip()

    async with message.channel.typing():
        response, debug_info = await handle_chat(cleaned_message)
    print(response)

    # An empty body cannot be sent, and an over-long one must be split.
    if response.strip():
        reply_parts = _split_for_discord(response)
    else:
        reply_parts = ["(empty response)"]
    sent_message = await message.reply(reply_parts[0])
    for part in reply_parts[1:]:
        await message.channel.send(part)

    # Create a thread from the sent message
    if is_debug:
        thread_name = f"""Debug Thread: '{cleaned_message}'"""[:_DISCORD_THREAD_NAME_LIMIT]
        thread = await sent_message.create_thread(name=thread_name, auto_archive_duration=60)
        # The debug report contains the prompt, context and reply, so it is
        # sent in pieces that each fit in one message.
        for part in _split_for_discord(debug_info):
            await thread.send(part)
# Run the bot with your token.
# This call sits at module level, so it executes whenever the script runs,
# after the event handlers above have been registered on `client`.
# NOTE(review): DISCORD_TOKEN comes from os.environ.get("discord_key") and is
# None when that variable is unset — confirm the deployment always sets it.
client.run(DISCORD_TOKEN)