|
import os |
|
import base64 |
|
import requests |
|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
from dataclasses import dataclass |
|
import pytesseract |
|
from PIL import Image |
|
|
|
@dataclass
class ChatMessage:
    """Minimal chat-message record; huggingface_hub does not ship one."""

    role: str      # e.g. "system", "user", "assistant"
    content: str   # message text

    def to_dict(self):
        """Return the message as a plain dict suitable for JSON serialization."""
        return dict(role=self.role, content=self.content)
|
|
|
class XylariaChat:
    """Gradio chat application backed by a HuggingFace chat-completion model.

    Supports optional image captioning (BLIP via the HF inference API) and
    math OCR (pytesseract) as auxiliary inputs, a rolling conversation
    history, and a small persistent key/value memory.
    """

    # Single source of truth for the chat model (was duplicated in three places).
    MODEL_NAME = "Qwen/QwQ-32B-Preview"
    # Rough overall context budget used to size the completion request.
    CONTEXT_TOKEN_BUDGET = 16384
    # Hard cap on how many tokens a single completion may request.
    MAX_COMPLETION_TOKENS = 10020

    def __init__(self):
        # Fail fast if the API token is missing; nothing works without it.
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token not found in environment variables")

        self.client = InferenceClient(
            model=self.MODEL_NAME,
            api_key=self.hf_token
        )

        # BLIP captioning endpoint used to describe uploaded images.
        self.image_api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}

        # Rolling chat history: list of {"role": ..., "content": ...} dicts.
        self.conversation_history = []
        # Key/value store injected into the prompt as "remembered" facts.
        self.persistent_memory = {}

        self.system_prompt = """You are a helpful and harmless assistant. You are Xylaria developed by Sk Md Saad Amin. You should think step-by-step. You should respond to image questions"""

    def store_information(self, key, value):
        """Store important information in persistent memory."""
        self.persistent_memory[key] = value
        return f"Stored: {key} = {value}"

    def retrieve_information(self, key):
        """Retrieve information from persistent memory."""
        return self.persistent_memory.get(key, "No information found for this key.")

    def reset_conversation(self):
        """Completely reset the conversation history and persistent memory.

        Also recreates the API client so no server-side session state
        survives the reset. Returns None (used directly as a Gradio
        callback that blanks the chatbot widget).
        """
        self.conversation_history = []
        self.persistent_memory.clear()

        try:
            self.client = InferenceClient(
                model=self.MODEL_NAME,
                api_key=self.hf_token
            )
        except Exception as e:
            # Best-effort: a failed re-init shouldn't crash the UI callback.
            print(f"Error resetting API client: {e}")

        return None

    def caption_image(self, image):
        """Caption an uploaded image using the Hugging Face inference API.

        Args:
            image: File path, base64 string (optionally a data URI), or a
                file-like object with a ``read()`` method.

        Returns:
            str: Image caption, or an error message on failure.
        """
        try:
            if isinstance(image, str) and os.path.isfile(image):
                # Local file path.
                with open(image, "rb") as f:
                    data = f.read()
            elif isinstance(image, str):
                # Base64 payload, possibly wrapped in a data URI.
                if image.startswith('data:image'):
                    image = image.split(',')[1]
                data = base64.b64decode(image)
            else:
                # Assume a file-like object.
                data = image.read()

            # Timeout prevents a stalled inference endpoint from hanging the
            # UI thread forever; a timeout raises and is reported below.
            response = requests.post(
                self.image_api_url,
                headers=self.image_api_headers,
                data=data,
                timeout=60
            )

            if response.status_code == 200:
                caption = response.json()[0].get('generated_text', 'No caption generated')
                return caption
            else:
                return f"Error captioning image: {response.status_code} - {response.text}"

        except Exception as e:
            return f"Error processing image: {str(e)}"

    def perform_math_ocr(self, image_path):
        """Perform OCR on an image and return the extracted text.

        Args:
            image_path (str): Path to the image file.

        Returns:
            str: Extracted text (stripped), or an error message.
        """
        try:
            img = Image.open(image_path)
            text = pytesseract.image_to_string(img)
            return text.strip()
        except Exception as e:
            return f"Error during Math OCR: {e}"

    def get_response(self, user_input, image=None):
        """Generate a streaming response for the user's message.

        Args:
            user_input (str): User's message.
            image (optional): Uploaded image (path, base64, or file-like).

        Returns:
            An iterable stream of chat-completion chunks, or an error string.
        """
        try:
            # Assemble the prompt: system prompt, remembered facts,
            # prior turns, then the (possibly image-augmented) user turn.
            messages = [{"role": "system", "content": self.system_prompt}]

            if self.persistent_memory:
                memory_context = "Remembered Information:\n" + "\n".join(
                    [f"{k}: {v}" for k, v in self.persistent_memory.items()]
                )
                messages.append({"role": "system", "content": memory_context})

            for msg in self.conversation_history:
                messages.append({"role": msg['role'], "content": msg['content']})

            if image:
                image_caption = self.caption_image(image)
                user_input = f"Image description: {image_caption}\n\nUser's message: {user_input}"

            messages.append({"role": "user", "content": user_input})

            # Crude whitespace-based token estimate; leave 50 tokens of
            # headroom. Clamp to at least 1 — long histories previously
            # drove this negative, producing an invalid request.
            input_tokens = sum(len(msg['content'].split()) for msg in messages)
            max_new_tokens = self.CONTEXT_TOKEN_BUDGET - input_tokens - 50
            max_new_tokens = max(1, min(max_new_tokens, self.MAX_COMPLETION_TOKENS))

            stream = self.client.chat_completion(
                messages=messages,
                model=self.MODEL_NAME,
                temperature=0.7,
                max_tokens=max_new_tokens,
                top_p=0.9,
                stream=True
            )

            return stream

        except Exception as e:
            print(f"Detailed error in get_response: {e}")
            return f"Error generating response: {str(e)}"

    def messages_to_prompt(self, messages):
        """Convert a list of message dicts to a single prompt string.

        Simple chat-template rendering; adjust the control tokens to match
        the specific model if it is ever used for raw text generation.
        """
        prompt = ""
        for msg in messages:
            if msg["role"] == "system":
                prompt += f"<|system|>\n{msg['content']}<|end|>\n"
            elif msg["role"] == "user":
                prompt += f"<|user|>\n{msg['content']}<|end|>\n"
            elif msg["role"] == "assistant":
                prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
        prompt += "<|assistant|>\n"
        return prompt

    def create_interface(self):
        """Build and return the Gradio Blocks UI for the chat application."""

        def streaming_response(message, chat_history, image_filepath, math_ocr_image_path):
            """Gradio callback: preprocess optional OCR/image inputs, stream
            the model reply into the chat display, then persist the turn."""
            # Optional math-OCR preprocessing.
            ocr_text = ""
            if math_ocr_image_path:
                ocr_text = self.perform_math_ocr(math_ocr_image_path)
                if ocr_text.startswith("Error"):
                    # Surface the OCR failure as the assistant reply and stop.
                    updated_history = chat_history + [[message, ocr_text]]
                    yield "", updated_history, None, None
                    return
                else:
                    message = f"Math OCR Result: {ocr_text}\n\nUser's message: {message}"

            if image_filepath:
                response_stream = self.get_response(message, image_filepath)
            else:
                response_stream = self.get_response(message)

            # get_response returns a plain string (not a stream) on failure.
            if isinstance(response_stream, str):
                updated_history = chat_history + [[message, response_stream]]
                yield "", updated_history, None, None
                return

            full_response = ""
            updated_history = chat_history + [[message, ""]]

            try:
                for chunk in response_stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        chunk_content = chunk.choices[0].delta.content
                        full_response += chunk_content
                        # Re-yield the growing reply so the UI updates live.
                        updated_history[-1][1] = full_response
                        yield "", updated_history, None, None
            except Exception as e:
                print(f"Streaming error: {e}")
                updated_history[-1][1] = f"Error during response: {e}"
                yield "", updated_history, None, None
                return

            # Persist the completed exchange; keep only the last 10 entries
            # (5 exchanges) to bound prompt growth.
            self.conversation_history.append(
                {"role": "user", "content": message}
            )
            self.conversation_history.append(
                {"role": "assistant", "content": full_response}
            )
            if len(self.conversation_history) > 10:
                self.conversation_history = self.conversation_history[-10:]

        def clear_conversation_display():
            """Blank the chat widget AND the server-side history.

            BUG FIX: the button previously cleared only the widget, so the
            stale history was still replayed into the next request.
            """
            self.conversation_history = []
            return None

        custom_css = """
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
        body, .gradio-container {
            font-family: 'Inter', sans-serif !important;
        }
        .chatbot-container .message {
            font-family: 'Inter', sans-serif !important;
        }
        .gradio-container input,
        .gradio-container textarea,
        .gradio-container button {
            font-family: 'Inter', sans-serif !important;
        }
        /* Image Upload Styling */
        .image-container {
            border: 1px solid #ccc;
            border-radius: 8px;
            padding: 10px;
            margin-bottom: 10px;
            display: flex;
            flex-direction: column;
            align-items: center;
            gap: 10px;
            background-color: #f8f8f8;
        }
        .image-preview {
            max-width: 200px;
            max-height: 200px;
            border-radius: 8px;
        }
        .image-buttons {
            display: flex;
            gap: 10px;
        }
        .image-buttons button {
            padding: 8px 15px;
            border-radius: 5px;
            background-color: #4CAF50;
            color: white;
            border: none;
            cursor: pointer;
        }
        .image-buttons button:hover {
            background-color: #367c39;
        }
        """

        with gr.Blocks(theme='soft', css=custom_css) as demo:
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Xylaria 1.5 Senoa (EXPERIMENTAL)",
                    height=500,
                    show_copy_button=True,
                )

                # Collapsible image input (captioned and prepended to the prompt).
                with gr.Accordion("Image Input", open=False):
                    with gr.Column() as image_container:
                        img = gr.Image(
                            sources=["upload", "webcam"],
                            type="filepath",
                            label="",
                            elem_classes="image-preview",
                        )
                        with gr.Row():
                            clear_image_btn = gr.Button("Clear Image")

                # Collapsible math-OCR input (text extracted via pytesseract).
                with gr.Accordion("Math Input", open=False):
                    with gr.Column():
                        math_ocr_img = gr.Image(
                            sources=["upload", "webcam"],
                            type="filepath",
                            label="Upload Image for math",
                            elem_classes="image-preview"
                        )
                        with gr.Row():
                            clear_math_ocr_btn = gr.Button("Clear Math Image")

                with gr.Row():
                    with gr.Column(scale=4):
                        txt = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message...",
                            container=False
                        )
                    btn = gr.Button("Send", scale=1)

                with gr.Row():
                    clear = gr.Button("Clear Conversation")
                    clear_memory = gr.Button("Clear Memory")

            clear_image_btn.click(
                fn=lambda: None,
                inputs=None,
                outputs=[img],
                queue=False
            )

            clear_math_ocr_btn.click(
                fn=lambda: None,
                inputs=None,
                outputs=[math_ocr_img],
                queue=False
            )

            btn.click(
                fn=streaming_response,
                inputs=[txt, chatbot, img, math_ocr_img],
                outputs=[txt, chatbot, img, math_ocr_img]
            )
            txt.submit(
                fn=streaming_response,
                inputs=[txt, chatbot, img, math_ocr_img],
                outputs=[txt, chatbot, img, math_ocr_img]
            )

            # Clears both the display and the internal history (see helper).
            clear.click(
                fn=clear_conversation_display,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )

            clear_memory.click(
                fn=self.reset_conversation,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )

            # Start every page load from a clean slate.
            demo.load(self.reset_conversation, None, None)

        return demo
|
|
|
|
|
def main():
    """Entry point: build the Xylaria chat UI and serve it.

    Launches with a public share link and debug logging enabled.
    """
    app = XylariaChat()
    app.create_interface().launch(share=True, debug=True)


if __name__ == "__main__":
    main()