|
import os |
|
import base64 |
|
import json |
|
import requests |
|
from typing import Dict, Any, Optional, Union |
|
from pathlib import Path |
|
import asyncio |
|
from mistralai.extra.mcp.sse import MCPClientSSE, SSEServerParams |
|
from mistralai import Mistral |
|
from mistralai.models import UserMessage, AssistantMessage, ToolMessage |
|
from pydantic import BaseModel |
|
from IPython.display import Audio, display |
|
import platform |
|
import subprocess |
|
import urllib.parse |
|
from gtts import gTTS |
|
|
|
|
|
class AnalysisDescription(BaseModel):
    """Structured result schema for a document analysis (pydantic model).

    NOTE(review): declared here but not referenced elsewhere in this file —
    presumably consumed by callers outside this view; confirm before removal.
    """

    # Category of the analyzed document (e.g. "report", "analysis").
    document_type: str
    # Bullet-point findings extracted from the document.
    key_findings: list[str]
    # Free-text summary of the analysis.
    summary: str
    # Arbitrary extra details (page counts, timings, ...).
    metadata: Dict[str, Any]
    # Analyzer confidence; presumably in [0, 1] — TODO confirm with producer.
    confidence_score: float
|
|
|
# Mistral chat model used by every agent and workflow in this script.
MODEL = "mistral-large-latest"
|
|
|
def play_wav(url: str, save_path: str = "/tmp/audio.wav") -> str:
    """Resolve an audio URL to a local file path.

    ``file://`` URLs are converted to a filesystem path in place; HTTP(S)
    URLs are downloaded to *save_path*.  Despite the name, nothing is played
    here — the caller receives a path to the audio file.

    Args:
        url: ``file://`` or HTTP(S) URL of a WAV file.
        save_path: Destination for downloaded files (remote URLs only).

    Returns:
        Local filesystem path on success, or an ``"Error: <msg>"`` string
        on failure (best-effort contract preserved from the original).
    """
    from urllib.request import url2pathname

    try:
        if url.startswith("file://"):
            # url2pathname percent-decodes and handles OS-specific quirks
            # (e.g. the leading "/C:/" form for Windows drive paths).
            # The previous lstrip("/") corrupted absolute POSIX paths by
            # making them relative ("/tmp/x.wav" -> "tmp/x.wav").
            file_path = url2pathname(urllib.parse.urlparse(url).path)
        else:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            with open(save_path, 'wb') as f:
                f.write(response.content)
            file_path = save_path
        return file_path
    except Exception as e:
        # Best-effort: report the failure as a string rather than raising.
        print(f"Error handling audio: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
|
|
def create_doc_agent(client: Mistral):
    """Register the document-processing agent on the Mistral platform."""
    params = {
        "type": "object",
        "properties": {
            "file_path": {
                "type": "string",
                "description": "Path to the document file",
            },
            "url": {
                "type": "string",
                "description": "URL to the document",
            },
            "document_type": {
                "type": "string",
                "description": "Type of climate document (report, analysis, data, etc.)",
            },
        },
    }
    tool = {
        "type": "function",
        "function": {
            "name": "process_climate_document",
            "description": "Process climate documents from file path or URL and extract structured data",
            "parameters": params,
        },
    }
    return client.beta.agents.create(
        model=MODEL,
        name="DocAgent",
        description="Converts OCR PDFs to JSON using document processing capabilities",
        instructions="Process documents by extracting text and structure, then convert to JSON format. Focus on climate-related documents and extract key data points.",
        tools=[tool],
    )
|
|
|
|
|
def create_image_agent(client: Mistral):
    """Register the image-analysis agent on the Mistral platform."""
    params = {
        "type": "object",
        "properties": {
            "image_data": {
                "type": "string",
                "description": "Base64-encoded image data",
            },
            "image_format": {
                "type": "string",
                "description": "Image format (png, jpg, pdf, etc.)",
            },
            "analysis_focus": {
                "type": "string",
                "description": "Specific focus for analysis (text_extraction, chart_analysis, table_extraction)",
            },
        },
        "required": ["image_data", "image_format"],
    }
    tool = {
        "type": "function",
        "function": {
            "name": "analyze_image",
            "description": "Analyze image documents and extract structured data",
            "parameters": params,
        },
    }
    return client.beta.agents.create(
        model=MODEL,
        name="ImageAgent",
        description="Converts image PDFs to JSON using image analysis capabilities",
        instructions="Analyze image-based documents, extract text and visual elements, then structure the data as JSON. Handle charts, graphs, and tabular data effectively.",
        tools=[tool],
    )
|
|
|
|
|
def create_json_analyzer_agent(client: Mistral):
    """Register the JSON-analysis agent on the Mistral platform."""
    params = {
        "type": "object",
        "properties": {
            "json_data": {
                "type": "object",
                "description": "JSON data to analyze",
            },
            "analysis_type": {
                "type": "string",
                "description": "Type of analysis to perform (statistical, content, structural)",
            },
        },
        "required": ["json_data"],
    }
    tool = {
        "type": "function",
        "function": {
            "name": "analyze_json_data",
            "description": "Process and analyze JSON data to extract insights and patterns",
            "parameters": params,
        },
    }
    return client.beta.agents.create(
        model=MODEL,
        name="JsonAnalyzerAgent",
        description="Analyzes JSON outputs from DocAgent or ImageAgent, producing detailed descriptions",
        instructions="Analyze JSON data structures, identify patterns, extract insights, and provide comprehensive analysis. Output should be structured and detailed.",
        tools=[tool],
    )
|
|
|
def create_speech_agent(client: Mistral):
    """Register the text-to-speech agent on the Mistral platform."""
    voice_schema = {
        "type": "object",
        "properties": {
            "speed": {"type": "number", "default": 1.0},
            "pitch": {"type": "number", "default": 1.0},
            "voice_type": {"type": "string", "default": "neutral"},
        },
    }
    params = {
        "type": "object",
        "properties": {
            "text": {
                "type": "string",
                "description": "Text to convert to speech",
            },
            "voice_settings": voice_schema,
        },
        "required": ["text"],
    }
    tool = {
        "type": "function",
        "function": {
            "name": "text_to_speech",
            "description": "Convert text to speech audio",
            "parameters": params,
        },
    }
    return client.beta.agents.create(
        model=MODEL,
        name="SpeechAgent",
        description="Converts text analysis from JsonAnalyzerAgent into speech",
        instructions="Convert text analysis into natural speech format. Optimize text for spoken delivery and handle technical content appropriately.",
        tools=[tool],
    )
|
|
|
|
|
def simulate_process_climate_document(file_path: Optional[str] = None, url: Optional[str] = None, document_type: str = "report") -> Dict[str, Any]:
    """Simulate document processing function.

    Returns a canned extraction result; ``source`` echoes *file_path*
    if given, otherwise *url*.
    """
    key_data = {
        "temperature_increase": "1.5°C",
        "co2_levels": "420ppm",
        "affected_regions": ["Arctic", "Coastal Areas", "Tropical Regions"],
    }
    metadata = {
        "pages": 45,
        "extraction_confidence": 0.92,
        "processing_time": "2.3s",
    }
    return {
        "document_id": "doc_001",
        "source": file_path or url,
        "type": document_type,
        "extracted_text": "Climate change impacts are increasing globally...",
        "key_data": key_data,
        "metadata": metadata,
    }
|
|
|
def simulate_analyze_image(image_data: str, image_format: str, analysis_focus: str = "text_extraction") -> Dict[str, Any]:
    """Simulate image analysis function.

    Echoes *image_format* and *analysis_focus* into a canned result;
    *image_data* itself is not inspected.
    """
    extracted = {
        "text": "Global Temperature Anomalies 2020-2024",
        "charts": ["line_chart_temperatures", "bar_chart_emissions"],
        "tables": [{"headers": ["Year", "Temperature", "Anomaly"], "rows": 5}],
    }
    visuals = {
        "charts_detected": 2,
        "tables_detected": 1,
        "text_regions": 8,
    }
    return {
        "image_id": "img_001",
        "format": image_format,
        "analysis_type": analysis_focus,
        "extracted_content": extracted,
        "visual_elements": visuals,
        "confidence": 0.88,
    }
|
|
|
def simulate_analyze_json_data(json_data: Dict[str, Any], analysis_type: str = "content") -> Dict[str, Any]:
    """Simulate JSON analysis function.

    Returns a fixed analysis payload; the arguments are accepted for
    interface compatibility but do not affect the result.
    """
    insights = [
        "Temperature data shows accelerating warming trend",
        "Regional variations indicate uneven climate impacts",
        "Emission data correlates with temperature increases",
    ]
    quality = {
        "completeness": 0.91,
        "consistency": 0.87,
        "reliability": 0.89,
    }
    advice = [
        "Focus on high-impact regions for intervention",
        "Monitor temperature trends quarterly",
        "Implement emission reduction strategies",
    ]
    return {
        "analysis_summary": "Comprehensive climate document analysis completed",
        "key_insights": insights,
        "data_quality": quality,
        "recommendations": advice,
    }
|
|
|
def simulate_text_to_speech(text: str, voice_settings: Optional[Dict[str, Any]] = None) -> str:
    """Render *text* to a WAV file via gTTS and return its ``file://`` URL.

    Args:
        text: Text to synthesize (requires network access for gTTS).
        voice_settings: Accepted for tool-schema compatibility; currently
            unused by the gTTS backend.

    Returns:
        ``file://`` URL of the generated audio on success, otherwise an
        ``"Error: <msg>"`` string.
    """
    # Fix: the default of None did not match the bare Dict[str, Any]
    # annotation — the parameter is optional, so mark it Optional.
    print(f"Converting to speech: {text[:100]}...")
    save_path = "/tmp/generated_speech.wav"
    try:
        tts = gTTS(text=text, lang="en")
        tts.save(save_path)
        # gTTS.save can fail silently in odd environments; verify the file.
        if not os.path.exists(save_path):
            raise FileNotFoundError(f"Failed to save audio to {save_path}")
        return f"file://{os.path.abspath(save_path)}"
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
async def process_document_workflow(client: Mistral, file_path: str, document_type: str = "climate_report"):
    """Ask the model to process a climate document; simulate any tool call.

    Returns the chat response, or None if anything fails.
    """
    print("Starting document processing workflow...")
    try:
        tool_defs = [
            {
                "type": "function",
                "function": {
                    "name": "process_climate_document",
                    "description": "Process climate documents from file path or URL and extract structured data",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "file_path": {"type": "string", "description": "Path to the document file"},
                            "url": {"type": "string", "description": "URL to the document"},
                            "document_type": {"type": "string", "description": "Type of climate document"},
                        },
                    },
                },
            }
        ]

        prompt = UserMessage(content=f"Process the climate document at {file_path} of type {document_type}")
        response = await client.chat.complete_async(
            model=MODEL,
            messages=[prompt],
            tools=tool_defs,
        )

        message = response.choices[0].message
        print("Document processing response:")
        print(message.content)

        # The tool call is simulated locally rather than executed remotely.
        for call in message.tool_calls or []:
            if call.function.name == "process_climate_document":
                doc_result = simulate_process_climate_document(file_path=file_path, document_type=document_type)
                print("Document processing result:")
                print(json.dumps(doc_result, indent=2))

        return response
    except Exception as e:
        print(f"Error in document workflow: {str(e)}")
        return None
|
|
|
async def process_image_workflow(client: Mistral, image_path: str, analysis_focus: str = "text_extraction"):
    """Base64-encode an image, ask the model to analyze it, simulate the tool.

    Returns the chat response, or None if anything fails (including a
    missing image file).
    """
    print("Starting image processing workflow...")
    try:
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")

        tool_defs = [
            {
                "type": "function",
                "function": {
                    "name": "analyze_image",
                    "description": "Analyze image documents and extract structured data",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "image_data": {"type": "string", "description": "Base64-encoded image data"},
                            "image_format": {"type": "string", "description": "Image format (png, jpg, pdf, etc.)"},
                            "analysis_focus": {"type": "string", "description": "Specific focus for analysis"},
                        },
                        "required": ["image_data", "image_format"],
                    },
                },
            }
        ]

        prompt = UserMessage(content=f"Analyze the image document at {image_path} with focus on {analysis_focus}")
        response = await client.chat.complete_async(
            model=MODEL,
            messages=[prompt],
            tools=tool_defs,
        )

        message = response.choices[0].message
        print("Image processing response:")
        print(message.content)

        # Simulate any requested analyze_image call locally.
        # NOTE(review): the format is hard-coded to "jpg" regardless of the
        # actual file extension — confirm with callers.
        for call in message.tool_calls or []:
            if call.function.name == "analyze_image":
                image_result = simulate_analyze_image(
                    image_data=encoded,
                    image_format="jpg",
                    analysis_focus=analysis_focus,
                )
                print("Image analysis result:")
                print(json.dumps(image_result, indent=2))

        return response
    except Exception as e:
        print(f"Error in image workflow: {str(e)}")
        return None
|
|
|
async def complete_analysis_workflow(client: Mistral, input_data: Dict[str, Any], max_retries: int = 3, initial_delay: float = 5.0):
    """Run the JSON-analysis step, then the text-to-speech step, end to end.

    Both model calls go through a retry helper with exponential backoff on
    rate-limit (HTTP 429) errors.  Tool calls requested by the model are
    simulated locally rather than executed.

    Args:
        client: Mistral API client.
        input_data: JSON payload passed to the simulated analyzer.
        max_retries: Maximum retries after a 429 error.
        initial_delay: Base backoff delay in seconds (doubled each retry).

    Returns:
        Tuple ``(json_response, speech_response)``; ``(None, None)`` on error.
    """
    print("Starting complete analysis workflow...")

    async def make_api_call(messages, tools, retry_count=0):
        # Wrapper around chat.complete_async that retries on rate limits.
        # 429 detection is string-based because the SDK's specific exception
        # type is not imported here.
        try:
            response = await client.chat.complete_async(
                model=MODEL,
                messages=messages,
                tools=tools
            )
            return response
        except Exception as e:
            if "429" in str(e) and retry_count < max_retries:
                # Exponential backoff: initial_delay, 2x, 4x, ...
                delay = initial_delay * (2 ** retry_count)
                print(f"Rate limit hit, retrying in {delay} seconds... (Attempt {retry_count + 1}/{max_retries})")
                await asyncio.sleep(delay)
                return await make_api_call(messages, tools, retry_count + 1)
            raise e

    try:
        # Function-tool schema advertised for the JSON-analysis step.
        json_analysis_tool = [
            {
                "type": "function",
                "function": {
                    "name": "analyze_json_data",
                    "description": "Process and analyze JSON data to extract insights and patterns",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "json_data": {"type": "object", "description": "JSON data to analyze"},
                            "analysis_type": {"type": "string", "description": "Type of analysis to perform"}
                        },
                        "required": ["json_data"]
                    }
                }
            }
        ]

        messages = [
            UserMessage(content="Analyze the provided JSON data and create a comprehensive analysis")
        ]

        json_response = await make_api_call(messages, json_analysis_tool)

        print("JSON Analysis response:")
        print(json_response.choices[0].message.content)

        # If the model asked for the tool, run the local simulation instead.
        if json_response.choices[0].message.tool_calls:
            for tool_call in json_response.choices[0].message.tool_calls:
                if tool_call.function.name == "analyze_json_data":
                    analysis_result = simulate_analyze_json_data(json_data=input_data)
                    print("Analysis result:")
                    print(json.dumps(analysis_result, indent=2))

        # Pause between the two API calls, presumably to avoid rate limits.
        await asyncio.sleep(2.0)

        # Function-tool schema advertised for the speech step.
        speech_tool = [
            {
                "type": "function",
                "function": {
                    "name": "text_to_speech",
                    "description": "Convert text to speech audio",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "text": {"type": "string", "description": "Text to convert to speech"},
                            "voice_settings": {
                                "type": "object",
                                "properties": {
                                    "speed": {"type": "number", "default": 1.0},
                                    "pitch": {"type": "number", "default": 1.0},
                                    "voice_type": {"type": "string", "default": "neutral"}
                                }
                            }
                        },
                        "required": ["text"]
                    }
                }
            }
        ]

        # NOTE(review): the spoken text is a fixed summary and is NOT derived
        # from the analysis result computed above — confirm this is intended.
        analysis_text = "Climate analysis reveals significant warming trends with regional variations requiring immediate attention."

        speech_messages = [
            UserMessage(content=f"Convert this analysis to speech: {analysis_text}")
        ]

        speech_response = await make_api_call(speech_messages, speech_tool)

        print("Speech conversion response:")
        print(speech_response.choices[0].message.content)

        # Simulate the TTS tool locally and resolve the resulting audio URL.
        if speech_response.choices[0].message.tool_calls:
            for tool_call in speech_response.choices[0].message.tool_calls:
                if tool_call.function.name == "text_to_speech":
                    audio_url = simulate_text_to_speech(text=analysis_text)
                    print(f"Generated audio URL: {audio_url}")

                    play_result = play_wav(audio_url)
                    print(f"Audio play result: {play_result}")

        return json_response, speech_response

    except Exception as e:
        print(f"Error in complete analysis workflow: {str(e)}")
        return None, None
|
|
|
async def tts_with_mcp(client: Mistral, text: str = "hello, and good luck for the hackathon"):
    """Ask the model to speak *text*; simulate the text_to_speech tool call.

    NOTE(review): despite the name, no MCP client is used here — the
    MCPClientSSE import at the top of the file is never exercised.
    """
    try:
        voice_props = {
            "speed": {"type": "number", "default": 1.0},
            "pitch": {"type": "number", "default": 1.0},
            "voice_type": {"type": "string", "default": "neutral"},
        }
        tts_tool = [
            {
                "type": "function",
                "function": {
                    "name": "text_to_speech",
                    "description": "Convert text to speech audio",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "text": {"type": "string", "description": "Text to convert to speech"},
                            "voice_settings": {"type": "object", "properties": voice_props},
                        },
                        "required": ["text"],
                    },
                },
            }
        ]

        print("Running TTS workflow...")
        response = await client.chat.complete_async(
            model=MODEL,
            messages=[UserMessage(content=f"Say '{text}' out loud!")],
            tools=tts_tool,
        )

        message = response.choices[0].message
        print("TTS Agent response:")
        print(message.content)

        for call in message.tool_calls or []:
            if call.function.name == "text_to_speech":
                audio_url = simulate_text_to_speech(text=text)
                print(f"Generated audio URL: {audio_url}")
                play_result = play_wav(audio_url)
                print(f"Audio play result: {play_result}")

        return response
    except Exception as e:
        print(f"Error in TTS workflow: {str(e)}")
        return None
|
|
|
async def main(client: Mistral):
    """Standalone TTS demo: direct gTTS synthesis, then the agent-driven flow.

    Returns the agent response from tts_with_mcp, or None on failure.
    """
    print("Running TTS workflow...")
    try:
        text = "hello, and good luck for the hackathon"
        save_path = "/tmp/output.wav"

        # Synthesize the greeting locally first.
        gTTS(text=text, lang="en").save(save_path)
        print(f"Audio saved to {save_path}")

        play_result = play_wav(f"file://{os.path.abspath(save_path)}")
        print(f"Audio play result: {play_result}")

        # Then run the same text through the model-mediated TTS workflow.
        run_result = await tts_with_mcp(client, text)
        if run_result:
            print("All run entries:")
            for entry in run_result.choices[0].message.content.splitlines():
                print(entry)
        return run_result
    except Exception as e:
        print(f"Error in TTS workflow: {str(e)}")
        return None
|
|
|
async def main_workflow(client: Mistral):
    """Create all four agents, then run the sample analysis workflow."""
    print("Mistral Multi-Agent Document Processing System Initialized")

    # Register each agent and remember it under its display label.
    agents = {
        "DocAgent": create_doc_agent(client),
        "ImageAgent": create_image_agent(client),
        "JsonAnalyzerAgent": create_json_analyzer_agent(client),
        "SpeechAgent": create_speech_agent(client),
    }

    print("Available agents:")
    for label, agent in agents.items():
        print(f"- {label} ID: {agent.id}")
    print("-" * 50)

    print("Skipping hardcoded document and image processing workflows in main_workflow.")
    print("Use the Gradio interface to upload and process files.")
    print("-" * 50)

    print("3. Running complete analysis workflow...")
    sample_data = {
        "temperature_data": [20.1, 20.5, 21.2, 21.8],
        "emissions": [400, 410, 415, 420],
        "regions": ["Global", "Arctic", "Tropical"],
    }
    analysis_response, speech_response = await complete_analysis_workflow(client, sample_data)
    print("-" * 50)

    if analysis_response:
        print("Analysis Response:")
        print(analysis_response.choices[0].message.content)
    else:
        print("No analysis response received")

    if speech_response:
        print("Speech Response:")
        print(speech_response.choices[0].message.content)
    else:
        print("No speech response received")

    print("All workflows completed!")
|
|
|
async def full_run(client: Mistral):
    """Execute the multi-agent workflow, then the standalone TTS demo."""
    await main_workflow(client)
    separator = "\n" + "=" * 50
    print(separator)
    print("Running TTS workflow...")
    await main(client)
|
|
|
if __name__ == "__main__":
    # Read the API key from the environment instead of hard-coding a
    # placeholder; falls back to the old literal so behavior is unchanged
    # when the variable is unset.
    client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "YOUR_API_KEY"))
    asyncio.run(full_run(client))