# Hugging Face Spaces page header (Space status: Running)
import streamlit as st | |
import hashlib | |
import os | |
import aiohttp | |
import asyncio | |
import time | |
from langsmith import traceable | |
import random | |
import discord | |
from transformers import pipeline | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
from pydantic import BaseModel | |
from typing import List, Optional | |
from tqdm import tqdm | |
import re | |
import os | |
# Configure the Streamlit page that accompanies the bot.
st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")

# Load the tokenizer and seq2seq model once at module level; query_teapot()
# reads both as globals. The former `= None` placeholders were dead
# assignments (overwritten immediately) and have been dropped.
model_name = "teapotai/teapotllm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def log_time(func):
    """Decorate an async function so every call prints how long it took.

    Args:
        func: The coroutine function to wrap.

    Returns:
        A coroutine function that awaits ``func`` with the same arguments,
        prints ``"<name> executed in <seconds> seconds"`` and returns the
        awaited result unchanged.
    """
    import functools  # local import: functools is not in the file's import list

    # wraps() keeps the wrapped function's __name__/__doc__ so introspection
    # (and libraries that register callbacks by name) still see the original.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the duration cannot go negative or
        # jump when the system clock is adjusted.
        start_time = time.perf_counter()
        result = await func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result

    return wrapper
# Brave Search subscription token, taken from the environment (None when unset).
API_KEY = os.getenv("brave_api_key")
async def brave_search(query, count=1):
    """Query the Brave web-search API and return the matching results.

    Args:
        query: Search text sent as the ``q`` parameter.
        count: Number of results requested from the API (default 1).

    Returns:
        A list of ``(title, description, url)`` tuples, one per web result.
        A field missing from a result is returned as an empty string. The
        list is empty when the API answers with a non-200 status, when the
        request fails at the transport level, or when there are no results.
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, params=params) as response:
                if response.status != 200:
                    print(f"Error: {response.status}, {await response.text()}")
                    return []
                results = await response.json()
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        # Treat network/decoding failures like an HTTP error: the caller
        # simply gets no search context instead of an unhandled exception.
        print(f"Error: Brave search request failed: {exc}")
        return []
    print(results)
    # .get() with a "" default: callers run str.replace() on the description,
    # so a result lacking a field must still yield a string.
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results.get("web", {}).get("results", [])
    ]
async def query_teapot(prompt, context, user_input):
    """Generate a reply from the module-level model for one chat turn.

    The three parts are joined with newlines into a single input string,
    which is printed, tokenized with the module-level ``tokenizer`` and
    passed to ``model.generate``.

    Args:
        prompt: System-style instruction text.
        context: Supporting text placed between the prompt and the question.
        user_input: The user's message.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    input_text = prompt + "\n" + context + "\n" + user_input
    print(input_text)
    inputs = tokenizer(input_text, return_tensors="pt")
    # generate() is a blocking call; asyncio.to_thread runs it in a worker
    # thread so the event loop is not held up while the model runs.
    output = await asyncio.to_thread(model.generate, **inputs, max_new_tokens=512)
    # The previous timing / token-count locals were never used and have been
    # removed along with their comment about the output containing input tokens.
    return tokenizer.decode(output[0], skip_special_tokens=True)
async def handle_chat(user_input):
    """Answer one user message using web-search snippets as context.

    Runs a Brave search for ``user_input``, joins the result descriptions
    (with ``<strong>`` markup removed) into a context string, asks the model
    for a reply and measures how long each of the two steps took.

    Returns:
        A ``(response, debug_info)`` tuple: the model's reply and a
        multi-line report containing the prompt, context, timings and reply.
    """
    # --- search phase -----------------------------------------------------
    search_started = time.time()
    search_hits = await brave_search(user_input)
    search_seconds = time.time() - search_started

    snippets = []
    for _title, description, _url in search_hits:
        snippets.append(description.replace('<strong>', '').replace('</strong>', ''))
    context = "\n".join(snippets)

    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""

    # --- generation phase -------------------------------------------------
    generation_started = time.time()
    response = await query_teapot(prompt, context, user_input)
    generation_seconds = time.time() - generation_started

    debug_info = f"""
Prompt:
{prompt}
Context:
{context}
Search time: {search_seconds:.2f} seconds
Generation time: {generation_seconds:.2f} seconds
Response: {response}
"""
    return response, debug_info
# Fixed status line shown on the Streamlit page.
st.write("418 I'm a teapot")

# Bot token, taken from the environment (None when the variable is unset).
DISCORD_TOKEN = os.getenv("discord_key")

# Start from discord.py's default intents, switch on message-related events
# explicitly, and build the client with that set.
intents = discord.Intents.default()
intents.messages = True
client = discord.Client(intents=intents)
# Event when the bot has connected to the server
@client.event  # registers the coroutine with the client; without it the handler is never called
async def on_ready():
    """Print the account the bot is logged in as once the client is ready."""
    print(f'Logged in as {client.user}')
# Discord rejects message bodies over 2000 characters and thread names over
# 100 characters, so outgoing text is clamped to these limits.
_DISCORD_MESSAGE_LIMIT = 2000
_DISCORD_THREAD_NAME_LIMIT = 100


def _split_for_discord(text, limit=_DISCORD_MESSAGE_LIMIT):
    """Split *text* into consecutive chunks of at most *limit* characters."""
    return [text[start:start + limit] for start in range(0, len(text), limit)]


# Event when a message is received
@client.event  # registers the coroutine with the client; without it the handler is never called
async def on_message(message):
    """Reply to messages that mention the bot, optionally with a debug thread.

    Messages written by the bot itself, or that do not contain a mention of
    the bot, are ignored. Otherwise the mention (and an optional ``debug:``
    marker) is removed, the remaining text is answered via ``handle_chat``
    and the answer is sent as a reply. When ``debug:`` was present, a thread
    is created on the reply and the debug report is posted into it.
    """
    # Check if the message is from the bot itself to prevent a loop
    if message.author == client.user:
        return

    # Exit the function if the bot is not mentioned. Both the plain and the
    # nickname ("<@!id>") mention spellings are accepted.
    mention_forms = (f'<@{client.user.id}>', f'<@!{client.user.id}>')
    if not any(form in message.content for form in mention_forms):
        return

    print(message.content)
    is_debug = "debug:" in message.content

    async with message.channel.typing():
        cleaned_message = message.content.replace("debug:", "")
        for form in mention_forms:
            cleaned_message = cleaned_message.replace(form, "")
        # Removing the mention leaves stray whitespace around the question.
        cleaned_message = cleaned_message.strip()
        response, debug_info = await handle_chat(cleaned_message)
        print(response)

    # An empty body makes the send fail, and an over-long one is rejected.
    reply_text = response[:_DISCORD_MESSAGE_LIMIT]
    if not reply_text.strip():
        reply_text = "(no response generated)"
    sent_message = await message.reply(reply_text)

    # Create a thread from the sent message
    if is_debug:
        thread_name = f"""Debug Thread: '{cleaned_message}'"""[:_DISCORD_THREAD_NAME_LIMIT]
        thread = await sent_message.create_thread(name=thread_name, auto_archive_duration=60)
        # Send the debug report in the created thread, split to fit the
        # message size limit; whitespace-only chunks cannot be sent.
        for chunk in _split_for_discord(debug_info):
            if chunk.strip():
                await thread.send(chunk)
# Run the bot with your token. Fail fast with an actionable message when the
# token was not provided, instead of an error from deep inside the client.
if not DISCORD_TOKEN:
    raise RuntimeError("Discord bot token is missing: set the 'discord_key' environment variable")
client.run(DISCORD_TOKEN)