File size: 7,357 Bytes
17dfeca
 
 
e2b19bb
 
17dfeca
 
 
fbb941f
17dfeca
 
 
 
 
 
 
 
00915d5
17dfeca
 
00915d5
 
 
 
 
17dfeca
 
 
e2b19bb
17dfeca
00915d5
17dfeca
 
 
 
 
00915d5
17dfeca
 
 
79e2fd1
17dfeca
 
 
e2b19bb
 
 
 
 
 
 
 
 
 
 
d97238d
 
 
6ebe85c
 
 
 
 
 
 
 
 
 
 
 
 
 
207d7c6
 
 
 
 
6ebe85c
 
 
 
 
 
 
 
 
 
 
 
 
207d7c6
6ebe85c
 
 
207d7c6
 
6ebe85c
 
 
 
d9a8caa
d97238d
90ef2fa
 
 
 
 
 
d9a8caa
 
 
 
17dfeca
d97238d
17dfeca
e2b19bb
00915d5
506ca16
6ebe85c
79e2fd1
6ebe85c
 
2ce1709
bc74d1a
6ebe85c
 
ba6b29d
207d7c6
6ebe85c
ba6b29d
6ebe85c
c281301
 
ba6b29d
c281301
 
ba6b29d
 
bc74d1a
c281301
 
eae27f8
 
 
 
 
 
 
 
c281301
 
00915d5
44b3e03
00915d5
29c9904
 
 
 
 
 
 
17dfeca
bddccb9
 
 
00915d5
 
29c9904
 
00915d5
29c9904
17dfeca
00915d5
17dfeca
fbb941f
f958b60
fbb941f
 
00915d5
 
fbb941f
 
 
 
00915d5
fbb941f
 
 
 
00915d5
fbb941f
 
00915d5
fbb941f
 
ad2ccd3
00915d5
ad2ccd3
 
00915d5
66ad49d
00915d5
eae27f8
fbc4c18
eae27f8
a83ab40
fbc4c18
3947bbb
29c9904
00915d5
fbc4c18
8e6d75f
00915d5
 
fbc4c18
29c9904
00915d5
90ef2fa
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import streamlit as st
import hashlib
import os
import aiohttp
import asyncio
import time
from langsmith import traceable
import random
import discord
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pydantic import BaseModel
from typing import List, Optional
from tqdm import tqdm
import re
import os

st.set_page_config(page_title="TeapotAI Discord Bot", page_icon=":robot_face:", layout="wide")


@st.cache_resource
def _load_teapot(name):
    """Return ``(tokenizer, model)`` for the Hugging Face checkpoint *name*.

    Streamlit re-executes this whole script for each new browser session and
    each rerun. ``st.cache_resource`` makes the weights load once per server
    process and hands back the same objects afterwards. Without it, another
    copy of the model would be loaded into memory on every run.
    """
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSeq2SeqLM.from_pretrained(name),
    )


model_name = "teapotai/teapotllm"
# Shared by every request; loaded (or fetched from the resource cache) at import.
tokenizer, model = _load_teapot(model_name)


def log_time(func):
    """Wrap coroutine function *func* so every awaited call prints its duration.

    After ``func`` returns, the wrapper prints
    ``"<func name> executed in <seconds> seconds"`` and passes the result
    through. If ``func`` raises, the exception propagates and nothing is
    printed.

    ``functools.wraps`` copies ``__name__``, ``__doc__`` and ``__wrapped__``
    onto the wrapper. Outer decorators that read the function's identity
    (such as ``@traceable`` stacked on top) therefore see the original
    function rather than ``wrapper``.
    """
    from functools import wraps

    @wraps(func)
    async def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        result = await func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed:.4f} seconds")
        return result

    return wrapper


API_KEY = os.environ.get("brave_api_key")

@log_time
async def brave_search(query, count=1, timeout=10):
    """Query the Brave Web Search API and return its web results.

    Args:
        query: Search string, sent as the ``q`` parameter.
        count: Number of results requested, sent as ``count``.
        timeout: Total request timeout in seconds.

    Returns:
        A list of ``(title, description, url)`` tuples built from
        ``response["web"]["results"]``. A field missing from a result becomes
        ``""``. Returns ``[]`` on a non-200 status, a client/network error,
        or a timeout, so callers can treat every failure as "no results".
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    try:
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url, headers=headers, params=params) as response:
                if response.status != 200:
                    print(f"Error: {response.status}, {await response.text()}")
                    return []
                results = await response.json()
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        # Without this, a network hiccup aborts the whole chat handler and the
        # Discord user never gets a reply.
        print(f"Error: brave_search request failed: {exc!r}")
        return []

    print(results)
    return [
        (res.get("title", ""), res.get("description", ""), res.get("url", ""))
        for res in results.get("web", {}).get("results", [])
    ]




import re
import urllib.request
import html  # For decoding HTML escape codes

# An http(s) URL, optionally wrapped in angle brackets (``<https://...>``, the
# form Discord uses to suppress link embeds). Group 1 is the URL itself. It
# stops at whitespace or an angle bracket and never ends in trailing sentence
# punctuation, so "see https://x.com/page." yields "https://x.com/page".
_URL_PATTERN = re.compile(r"""<?(https?://[^\s<>]*[^\s<>.,;:!?'"])>?""")


def extract_first_url(query):
    """Split the first http(s) URL out of *query*.

    Returns:
        ``(query_without_urls, first_url)``. Every URL, together with any
        angle brackets wrapped around it, is removed from the text. The
        first URL is returned separately, or ``None`` when the text contains
        no URL (the query is then returned unchanged).
    """
    match = _URL_PATTERN.search(query)
    if match is None:
        return query, None
    return _URL_PATTERN.sub("", query), match.group(1)

async def extract_text_from_html(url, max_words=250, max_chars=2000, timeout=15):
    """Fetch *url* and return a plain-text excerpt built from its ``<p>`` elements.

    Processing steps:

    1. Extract the inner HTML of every ``<p>`` element (attributes and any
       letter case allowed).
    2. Strip nested tags and decode HTML entities.
    3. Join the paragraphs with spaces.
    4. Truncate to the first ``max_words`` whitespace-separated words, then
       to ``max_chars`` characters.

    Args:
        url: Address to fetch. In this bot it comes straight from a Discord
            message, i.e. it is untrusted input.
        max_words: Maximum number of words kept.
        max_chars: Maximum number of characters kept (applied after the word
            limit).
        timeout: Total request timeout in seconds.

    Returns:
        The excerpt, or ``""`` if the request fails, times out, or the server
        answers with an HTTP error status (>= 400).
    """
    # SECURITY(review): the URL is user-controlled and fetched without any
    # host allow/deny list, so the bot can be pointed at internal or metadata
    # addresses (SSRF). Consider rejecting private/loopback hosts first.
    try:
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url) as response:
                if response.status >= 400:
                    # An error page is not the content the user asked about.
                    print(f"Error fetching {url}: HTTP {response.status}")
                    return ""
                # errors="replace" keeps mislabeled or binary responses from
                # raising UnicodeDecodeError.
                html_content = await response.text(errors="replace")
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        print(f"Error fetching {url}: {exc!r}")
        return ""

    # [^>]* admits attributes such as <p class="...">; \b keeps <pre> and
    # <param> from matching; IGNORECASE admits <P>.
    paragraphs = re.findall(
        r"<p\b[^>]*>(.*?)</p>", html_content, re.DOTALL | re.IGNORECASE
    )

    # Strip nested tags ([^>]* also spans newlines inside a tag), then decode
    # HTML escape codes (e.g. &lt; -> <).
    decoded = [html.unescape(re.sub(r"<[^>]*>", "", p)) for p in paragraphs]

    words = " ".join(decoded).split()
    return " ".join(words[:max_words])[:max_chars]




# NOTE: nothing serializes access to the shared model. Concurrent Discord
# messages may run model.generate in parallel worker threads. If that proves
# problematic, guard the generate call with a module-level asyncio.Lock.

@traceable
@log_time
async def query_teapot(prompt, context, user_input):
    """Generate a Teapot reply for the given prompt, context and user input.

    The three strings are joined with newlines and tokenized on the event
    loop. ``model.generate`` (capped at 512 new tokens) runs in a worker
    thread via ``asyncio.to_thread`` so the event loop is not blocked while
    the model runs. The first generated sequence is decoded with special
    tokens skipped and returned.
    """
    model_input = "\n".join((prompt, context, user_input))
    encoded = tokenizer(model_input, return_tensors="pt")

    generated = await asyncio.to_thread(model.generate, **encoded, max_new_tokens=512)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


@log_time
async def handle_chat(user_input):
    """Answer one chat message with Teapot, gathering context first.

    The context is chosen by the first matching route:

    * The message contains a URL. The text of the first URL's page (via
      ``extract_text_from_html``) becomes the context, and all URLs are
      removed from the query.
    * The message contains "translate" (any case). The model is called with
      no system prompt and no context.
    * The message is shorter than 400 characters and has no "context:" marker
      (any case). It is sent to ``brave_search``, and the result descriptions
      become the context.
    * Otherwise the message is assumed to carry its own context.

    Returns:
        ``(response, debug_info)``. ``debug_info`` is a multi-line report of
        the prompt, context, query, timings and response. It is ``""`` when the
        search returned nothing and the canned apology is returned instead.
    """
    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization."""

    processed_query, url = extract_first_url(user_input)
    # Case-insensitive trigger checks: "Translate ..." and "CONTEXT:" count too.
    lowered = user_input.lower()

    search_start_time = time.perf_counter()
    if url:
        context = await extract_text_from_html(url)
        user_input = processed_query
    elif "translate" in lowered:
        # Translation requests go to the model bare: no system prompt, no context.
        prompt = ""
        context = ""
    elif len(user_input) < 400 and "context:" not in lowered:
        results = await brave_search(user_input)
        if not results:  # No information
            return "I'm sorry but I don't have any information on that.", ""
        # Search descriptions may contain <strong> highlight tags; drop them.
        context = "\n".join(
            desc.replace("<strong>", "").replace("</strong>", "")
            for _, desc, _ in results
        )
    else:
        context = ""  # User provided their own context
    search_end_time = time.perf_counter()

    generation_start_time = time.perf_counter()
    response = await query_teapot(prompt, context, user_input)
    generation_end_time = time.perf_counter()

    debug_info = f"""
Prompt: 
{prompt}

Context: 
{context}

Query: 
{user_input}

Search time: {search_end_time - search_start_time:.2f} seconds
Generation time: {generation_end_time - generation_start_time:.2f} seconds
Response: {response}
"""

    return response, debug_info


st.write("418 I'm a teapot")

# Bot token from the environment (None when the variable is unset).
DISCORD_TOKEN = os.getenv("discord_key")

# Gateway intents: discord.py's defaults, with message events explicitly
# switched on so on_message is dispatched.
intents = discord.Intents.default()
intents.messages = True

# The single bot client; the event handlers in this file register on it.
client = discord.Client(intents=intents)

# Dispatched by discord.py once the client has logged in and is ready.
@client.event
async def on_ready():
    """Print which bot account this client is logged in as."""
    bot_user = client.user
    print(f'Logged in as {bot_user}')

# Discord refuses message content that is empty or longer than 2000 characters.
_DISCORD_MESSAGE_LIMIT = 2000
# Posted instead of an empty model output, which Discord would reject.
_EMPTY_REPLY_FALLBACK = "I'm sorry, I couldn't come up with a response to that."


def _chunk_for_discord(text, limit=_DISCORD_MESSAGE_LIMIT):
    """Split *text* into consecutive pieces of at most *limit* characters ([] for "")."""
    return [text[start:start + limit] for start in range(0, len(text), limit)]


# Event when a message is received
@client.event
async def on_message(message):
    """Answer messages that mention the bot.

    The bot ignores its own messages and any message that does not contain
    a mention of it. Both ``<@id>`` and the legacy nickname form ``<@!id>``
    count as mentions. The mention and any "debug:"/"Debug:" marker are
    stripped, and the rest is passed to ``handle_chat``. The answer is sent
    as a reply; text over Discord's 2000-character limit continues in
    follow-up channel messages. With a debug marker, a thread is opened on
    the reply and the debug report is posted there in chunks (nothing is
    posted when the report is empty).
    """
    # Check if the message is from the bot itself to prevent a loop
    if message.author == client.user:
        return

    bot_id = client.user.id
    mention_tokens = (f'<@{bot_id}>', f'<@!{bot_id}>')
    # Exit the function if the bot is not mentioned
    if not any(token in message.content for token in mention_tokens):
        return

    print(message.content)

    is_debug = "debug:" in message.content or "Debug:" in message.content
    async with message.channel.typing():
        cleaned_message = message.content.replace("debug:", "").replace("Debug:", "")
        for token in mention_tokens:
            cleaned_message = cleaned_message.replace(token, "")
        response, debug_info = await handle_chat(cleaned_message)
        print(response)

        # Whitespace-only content counts as empty for Discord, so substitute
        # a fallback; long answers are split across several messages.
        if response.strip():
            response_chunks = _chunk_for_discord(response)
        else:
            response_chunks = [_EMPTY_REPLY_FALLBACK]
        sent_message = await message.reply(response_chunks[0])
        for chunk in response_chunks[1:]:
            await message.channel.send(chunk)

        # Create a thread from the sent message
        if is_debug:
            # NOTE(review): thread creation is not possible in every channel
            # type (e.g. DMs); such failures still propagate as before.
            thread = await sent_message.create_thread(
                name=f"""Debug Thread: '{cleaned_message[:80]}'""",
                auto_archive_duration=60,
            )
            # The debug report routinely exceeds 2000 characters, and is ""
            # when the search found nothing; post it in limit-sized pieces.
            for chunk in _chunk_for_discord(debug_info):
                await thread.send(chunk)


@st.cache_resource
def initialize():
    """Mark the session as initialized and start the Discord client.

    ``st.cache_resource`` memoizes this zero-argument call, so later script
    reruns reuse the cached result instead of starting a second bot.
    discord.py documents ``Client.run`` as a blocking call. The call below
    therefore does not return while the bot is running.
    """
    st.session_state["initialized"] = True
    client.run(DISCORD_TOKEN)


initialize()