|
|
|
|
|
|
|
import streamlit as st |
|
import os |
|
from openai import OpenAI |
|
import tempfile |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_community.document_loaders import ( |
|
PyPDFLoader, |
|
TextLoader, |
|
CSVLoader |
|
) |
|
from datetime import datetime |
|
from pydub import AudioSegment |
|
import pytz |
|
import chromadb |
|
|
from langgraph.graph import StateGraph, START, END, add_messages |
|
from langgraph.constants import Send |
|
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage |
|
from pydantic import BaseModel |
|
from typing import List, Annotated, Any |
|
import re, operator |
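
# Chroma caches client instances per settings object; clearing the cache lets a
# rerun of the Streamlit script re-open the same persist directory cleanly.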
|
|
|
|
|
chromadb.api.client.SharedSystemClient.clear_system_cache() |
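
# LangGraph state schemas: MultiAgentState drives the outer topic-expansion
# graph; StoryState drives the per-subtopic retrieve -> rerank -> story subgraph.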
|
|
|
class MultiAgentState(BaseModel):
    state: List[str] = []
    messages: Annotated[list[AnyMessage], add_messages] = []
    topic: List[str] = []
    context: List[str] = []
    sub_topic_list: List[str] = []
    sub_topics: Annotated[list[AnyMessage], add_messages] = []
    stories: Annotated[list[AnyMessage], add_messages] = []
    stories_lst: Annotated[list, operator.add] = []
|
|
|
class StoryState(BaseModel):
    retrieved_docs: List[Any] = []
    reranked_docs: List[str] = []
    stories: Annotated[list[AnyMessage], add_messages] = []
    story_topic: str = ""
    stories_lst: Annotated[list, operator.add] = []
|
|
|
class DocumentRAG: |
|
    def __init__(self, embedding_choice="OpenAI"):
        self.document_store = None
        self.qa_chain = None
        self.document_summary = ""
        self.document_text = ""
        self.chat_history = []
        self.last_processed_time = None
        self.api_key = os.getenv("OPENAI_API_KEY")
        self.init_time = datetime.now(pytz.UTC)
        self.embedding_choice = embedding_choice
|
|
|
|
|
if self.embedding_choice == "Cohere": |
|
from langchain_cohere import ChatCohere |
|
import cohere |
|
self.llm = ChatCohere( |
|
model="command-r-plus-08-2024", |
|
temperature=0.7, |
|
cohere_api_key=os.getenv("COHERE_API_KEY") |
|
) |
|
self.cohere_client = cohere.Client(os.getenv("COHERE_API_KEY")) |
|
else: |
|
self.llm = ChatOpenAI( |
|
model_name="gpt-4", |
|
temperature=0.7, |
|
api_key=self.api_key |
|
) |
|
|
|
|
|
self.chroma_persist_dir = "./chroma_storage" |
|
os.makedirs(self.chroma_persist_dir, exist_ok=True) |
|
|
|
|
|
    def _get_embedding_model(self):
        if self.embedding_choice == "OpenAI":
            if not self.api_key:
                raise ValueError("API key not found. Make sure to set the 'OPENAI_API_KEY' environment variable.")
            return OpenAIEmbeddings(api_key=self.api_key)
        else:
            if not os.getenv("COHERE_API_KEY"):
                raise ValueError("API key not found. Make sure to set the 'COHERE_API_KEY' environment variable.")
            from langchain_cohere import CohereEmbeddings
            return CohereEmbeddings(
                model="embed-multilingual-light-v3.0",
                cohere_api_key=os.getenv("COHERE_API_KEY")
            )
|
|
|
|
|
|
|
|
|
def process_documents(self, uploaded_files): |
|
"""Process uploaded files by saving them temporarily and extracting content.""" |
|
if not self.api_key: |
|
return "Please set the OpenAI API key in the environment variables." |
|
if not uploaded_files: |
|
return "Please upload documents first." |
|
|
|
try: |
|
documents = [] |
|
for uploaded_file in uploaded_files: |
|
|
|
temp_file_path = tempfile.NamedTemporaryFile( |
|
delete=False, suffix=os.path.splitext(uploaded_file.name)[1] |
|
).name |
|
with open(temp_file_path, "wb") as temp_file: |
|
temp_file.write(uploaded_file.read()) |
|
|
|
|
|
if temp_file_path.endswith('.pdf'): |
|
loader = PyPDFLoader(temp_file_path) |
|
elif temp_file_path.endswith('.txt'): |
|
loader = TextLoader(temp_file_path) |
|
elif temp_file_path.endswith('.csv'): |
|
loader = CSVLoader(temp_file_path) |
|
else: |
|
return f"Unsupported file type: {uploaded_file.name}" |
|
|
|
|
|
try: |
|
documents.extend(loader.load()) |
|
except Exception as e: |
|
return f"Error loading {uploaded_file.name}: {str(e)}" |
|
|
|
if not documents: |
|
return "No valid documents were processed. Please check your files." |
|
|
|
|
|
text_splitter = RecursiveCharacterTextSplitter( |
|
chunk_size=1000, |
|
chunk_overlap=200, |
|
length_function=len |
|
) |
|
documents = text_splitter.split_documents(documents) |
|
|
|
|
|
self.document_text = " ".join([doc.page_content for doc in documents]) |
|
|
|
|
|
embeddings = self._get_embedding_model() |
|
self.document_store = Chroma.from_documents( |
|
documents, |
|
embeddings, |
|
persist_directory=self.chroma_persist_dir |
|
) |
|
|
|
            # Note: the Q&A chain always answers with GPT-4, regardless of the
            # embedding model selected above.
            self.qa_chain = ConversationalRetrievalChain.from_llm(
                ChatOpenAI(temperature=0, model_name='gpt-4', api_key=self.api_key),
                self.document_store.as_retriever(search_kwargs={'k': 6}),
                return_source_documents=True,
                verbose=False
            )
|
|
|
self.last_processed_time = datetime.now(pytz.UTC) |
|
return "Documents processed successfully!" |
|
except Exception as e: |
|
return f"Error processing documents: {str(e)}" |
|
|
|
def generate_summary(self, text, language): |
|
"""Generate a summary of the provided text focusing on specific sections in the specified language.""" |
|
if not self.api_key: |
|
return "API Key not set. Please set it in the environment variables." |
|
try: |
|
client = OpenAI(api_key=self.api_key) |
|
response = client.chat.completions.create( |
|
model="gpt-4", |
|
messages=[ |
|
{"role": "system", "content": f""" |
|
Summarize the following document focusing mainly on these sections: |
|
1. Abstract |
|
2. In the Introduction, specifically focus on the portion where the key contributions of the research paper are highlighted. |
|
3. Conclusion |
|
4. Limitations |
|
5. Future Work |
|
|
|
              Ensure the summary is concise, logically ordered, and written in {language}.
|
Provide 7-9 key points for discussion in a structured format."""}, |
|
{"role": "user", "content": text[:4000]} |
|
], |
|
temperature=0.3 |
|
) |
|
return response.choices[0].message.content |
|
except Exception as e: |
|
return f"Error generating summary: {str(e)}" |
|
|
|
def create_podcast(self, language): |
|
"""Generate a podcast script and audio based on doc summary in the specified language.""" |
|
if not self.document_summary: |
|
return "Please process documents before generating a podcast.", None |
|
|
|
if not self.api_key: |
|
return "Please set the OpenAI API key in the environment variables.", None |
|
|
|
try: |
|
client = OpenAI(api_key=self.api_key) |
|
|
|
|
|
script_response = client.chat.completions.create( |
|
model="gpt-4", |
|
messages=[ |
|
{"role": "system", "content": f""" |
|
You are a professional podcast producer. Create a 1-2 minute structured podcast dialogue in {language} |
|
based on the provided document summary. Follow this flow: |
|
1. Brief Introduction of the Topic |
|
2. Highlight the limitations of existing methods, the key contributions of the research paper, and its advantages over the current state of the art. |
|
3. Discuss Limitations of the research work. |
|
4. Present the Conclusion |
|
5. Mention Future Work |
|
|
|
Clearly label the dialogue as 'Host 1:' and 'Host 2:'. Maintain a tone that is engaging, conversational, |
|
and insightful, while ensuring the flow remains logical and natural. Include a well-structured opening |
|
to introduce the topic and a clear, thoughtful closing that provides a smooth conclusion, avoiding any |
|
abrupt endings.""" |
|
}, |
|
{"role": "user", "content": f""" |
|
Document Summary: {self.document_summary}"""} |
|
], |
|
temperature=0.7 |
|
) |
|
|
|
script = script_response.choices[0].message.content |
|
if not script: |
|
return "Error: Failed to generate podcast script.", None |
|
|
|
|
|
final_audio = AudioSegment.empty() |
|
is_first_speaker = True |
|
|
|
lines = [line.strip() for line in script.split("\n") if line.strip()] |
|
for line in lines: |
|
if ":" not in line: |
|
continue |
|
|
|
speaker, text = line.split(":", 1) |
|
if not text.strip(): |
|
continue |
|
|
|
try: |
|
voice = "nova" if is_first_speaker else "onyx" |
|
audio_response = client.audio.speech.create( |
|
model="tts-1", |
|
voice=voice, |
|
input=text.strip() |
|
) |
|
|
|
temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") |
|
                    audio_response.write_to_file(temp_audio_file.name)
|
|
|
segment = AudioSegment.from_file(temp_audio_file.name) |
|
final_audio += segment |
|
final_audio += AudioSegment.silent(duration=300) |
|
|
|
is_first_speaker = not is_first_speaker |
|
except Exception as e: |
|
print(f"Error generating audio for line: {text}") |
|
print(f"Details: {e}") |
|
continue |
|
|
|
if len(final_audio) == 0: |
|
return "Error: No audio could be generated.", None |
|
|
|
output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name |
|
final_audio.export(output_file, format="mp3") |
|
return script, output_file |
|
|
|
except Exception as e: |
|
return f"Error generating podcast: {str(e)}", None |
|
|
|
def handle_query(self, question, history, language): |
|
"""Handle user queries in the specified language.""" |
|
if not self.qa_chain: |
|
return history + [("System", "Please process the documents first.")] |
|
try: |
|
preface = ( |
|
f"Instruction: Respond in {language}. Be professional and concise, " |
|
f"keeping the response under 300 words. If you cannot provide an answer, say: " |
|
f'"I am not sure about this question. Please try asking something else."' |
|
) |
|
query = f"{preface}\nQuery: {question}" |
|
|
|
            result = self.qa_chain.invoke({
|
"question": query, |
|
"chat_history": [(q, a) for q, a in history] |
|
}) |
|
|
|
if "answer" not in result: |
|
return history + [("System", "Sorry, an error occurred.")] |
|
|
|
history.append((question, result["answer"])) |
|
return history |
|
except Exception as e: |
|
return history + [("System", f"Error: {str(e)}")] |
|
|
|
def extract_subtopics(self, messages): |
|
text = "\n".join([msg.content for msg in messages]) |
|
return re.findall(r'- \*\*(.*?)\*\*', text) |
|
|
|
    def beginner_topic(self, state: MultiAgentState):
        prompt = f"What are the beginner-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}?"
        msg = self.llm.invoke([SystemMessage("Suppose you're a middle grader..."), HumanMessage(prompt)])
        return {"messages": msg, "sub_topics": msg}

    def middle_topic(self, state: MultiAgentState):
        prompt = f"What are the middle-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
        msg = self.llm.invoke([SystemMessage("Suppose you're a college student..."), HumanMessage(prompt)])
        return {"messages": msg, "sub_topics": msg}

    def advanced_topic(self, state: MultiAgentState):
        prompt = f"What are the advanced-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
        msg = self.llm.invoke([SystemMessage("Suppose you're a teacher..."), HumanMessage(prompt)])
        return {"messages": msg, "sub_topics": msg}
|
|
|
def topic_extractor(self, state: MultiAgentState): |
|
return {"sub_topic_list": self.extract_subtopics(state.sub_topics)} |
|
|
|
|
|
    def retrieve_node(self, state: StoryState):
        embedding = self._get_embedding_model()
        retriever = Chroma(
            persist_directory=self.chroma_persist_dir,
            embedding_function=embedding
        ).as_retriever(search_kwargs={"k": 20})

        topic = state.story_topic
        query = f"information about {topic}"
        docs = retriever.invoke(query)
        return {"retrieved_docs": docs}
|
|
|
|
|
|
|
|
|
    def rerank_node(self, state: StoryState):
        topic = state.story_topic
        query = f"Rerank documents based on how well they explain the topic {topic}"
        docs = state.retrieved_docs
        texts = [doc.page_content for doc in docs]

        if not texts:
            return {"reranked_docs": []}

        if self.embedding_choice == "Cohere" and hasattr(self, "cohere_client"):
            rerank_results = self.cohere_client.rerank(
                query=query,
                documents=texts,
                top_n=5,
                model="rerank-v3.5"
            )
            top_docs = [texts[result.index] for result in rerank_results.results]
        else:
            # No reranker available: fall back to the five longest passages,
            # a crude proxy for informativeness.
            top_docs = sorted(texts, key=lambda t: -len(t))[:5]

        return {"reranked_docs": top_docs}
|
|
|
|
|
|
|
|
|
|
|
def generate_story_node(self, state: StoryState): |
|
context = "\n\n".join(state.reranked_docs) |
|
topic = state.story_topic |
|
|
|
system_message = f""" |
|
Suppose you're a brilliant science storyteller. |
|
You write stories that help middle schoolers understand complex science topics with fun and clarity. |
|
Add subtle humor and make it engaging. |
|
""" |
|
prompt = f""" |
|
Use the following context to write a fun and simple story explaining **{topic}** to a middle schooler:\n |
|
Context:\n{context}\n\n |
|
Story: |
|
""" |
|
|
|
msg = self.llm.invoke([SystemMessage(system_message), HumanMessage(prompt)]) |
|
return {"stories": msg} |
|
|
|
|
|
|
|
|
|
def run_multiagent_storygraph(self, topic: str, context: str): |
|
if self.embedding_choice == "OpenAI": |
|
self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key) |
|
elif self.embedding_choice == "Cohere": |
|
from langchain_cohere import ChatCohere |
|
self.llm = ChatCohere( |
|
model="command-r-plus-08-2024", |
|
temperature=0.7, |
|
cohere_api_key=os.getenv("COHERE_API_KEY") |
|
) |
|
|
|
|
|
story_graph = StateGraph(StoryState) |
|
story_graph.add_node("Retrieve", self.retrieve_node) |
|
story_graph.add_node("Rerank", self.rerank_node) |
|
story_graph.add_node("Generate", self.generate_story_node) |
|
story_graph.set_entry_point("Retrieve") |
|
story_graph.add_edge("Retrieve", "Rerank") |
|
story_graph.add_edge("Rerank", "Generate") |
|
story_graph.set_finish_point("Generate") |
|
story_subgraph = story_graph.compile() |
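
        # Outer graph: three difficulty-tiered topic agents run in sequence,
        # then the extractor fans out one story subgraph per subtopic via Send.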
|
|
|
|
|
graph = StateGraph(MultiAgentState) |
|
graph.add_node("beginner_topic", self.beginner_topic) |
|
graph.add_node("middle_topic", self.middle_topic) |
|
graph.add_node("advanced_topic", self.advanced_topic) |
|
graph.add_node("topic_extractor", self.topic_extractor) |
|
graph.add_node("story_generator", story_subgraph) |
|
|
|
graph.add_edge(START, "beginner_topic") |
|
graph.add_edge("beginner_topic", "middle_topic") |
|
graph.add_edge("middle_topic", "advanced_topic") |
|
graph.add_edge("advanced_topic", "topic_extractor") |
|
graph.add_conditional_edges( |
|
"topic_extractor", |
|
lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list], |
|
["story_generator"] |
|
) |
|
graph.add_edge("story_generator", END) |
|
|
|
compiled = graph.compile(checkpointer=MemorySaver()) |
|
thread = {"configurable": {"thread_id": "storygraph-session"}} |
|
|
|
|
|
result = compiled.invoke({"topic": [topic], "context": [context]}, thread) |
|
|
|
|
|
if not result.get("sub_topic_list"): |
|
fallback_subs = ["Neural Networks", "Reinforcement Learning", "Supervised vs Unsupervised"] |
|
compiled.update_state(thread, {"sub_topic_list": fallback_subs}) |
|
result = compiled.invoke(None, thread, stream_mode="values") |
|
|
|
return result |
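
# A minimal sketch of driving the pipeline outside Streamlit (hypothetical
# usage; assumes OPENAI_API_KEY is set in the environment):
#   rag = DocumentRAG(embedding_choice="OpenAI")
#   out = rag.run_multiagent_storygraph("Machine Learning", "Education")
#   print(out.get("sub_topic_list"))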
|
|
|
|
|
|
|
|
|
|
|
with st.sidebar: |
|
st.title("About") |
|
st.markdown( |
|
""" |
|
This app is inspired by the [RAG_HW HuggingFace Space](https://huggingface.co/spaces/wint543/RAG_HW). |
|
It allows users to upload documents, generate summaries, ask questions, and create podcasts. |
|
""" |
|
) |
|
st.markdown("### Steps:") |
|
st.markdown("1. Upload documents.") |
|
st.markdown("2. Generate summary.") |
|
st.markdown("3. Ask questions.") |
|
st.markdown("4. Create podcast.") |
|
|
|
st.markdown("### Credits:") |
|
st.markdown("Image Source: [Geeksforgeeks](https://www.geeksforgeeks.org/how-to-convert-document-into-podcast/)") |
|
|
|
|
|
st.title("Document Analyzer & Podcast Generator") |
|
st.image("./cover_image.png", use_container_width=True) |
|
|
|
|
|
st.subheader("Embedding Model Selection") |
|
embedding_choice = st.radio( |
|
"Choose the embedding model for document processing and story generation:", |
|
["OpenAI", "Cohere"], |
|
horizontal=True, |
|
key="embedding_model" |
|
) |
|
|
|
if "rag_system" not in st.session_state: |
|
st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice) |
|
elif st.session_state.rag_system.embedding_choice != embedding_choice: |
|
st.session_state.rag_system = DocumentRAG(embedding_choice=embedding_choice) |
|
|
|
|
|
|
|
st.subheader("Step 1: Upload and Process Documents") |
|
uploaded_files = st.file_uploader("Upload files (PDF, TXT, CSV)", accept_multiple_files=True) |
|
|
|
if st.button("Process Documents"): |
|
if uploaded_files: |
|
with st.spinner("Processing documents, please wait..."): |
|
result = st.session_state.rag_system.process_documents(uploaded_files) |
|
if "successfully" in result: |
|
st.success(result) |
|
else: |
|
st.error(result) |
|
else: |
|
st.warning("No files uploaded.") |
|
|
|
|
|
st.subheader("Step 2: Generate Summary") |
|
st.write("Select Summary Language:") |
|
summary_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"] |
|
summary_language = st.radio( |
|
"", |
|
summary_language_options, |
|
horizontal=True, |
|
key="summary_language" |
|
) |
|
|
|
if st.button("Generate Summary"): |
|
if hasattr(st.session_state.rag_system, "document_text") and st.session_state.rag_system.document_text: |
|
with st.spinner("Generating summary, please wait..."): |
|
summary = st.session_state.rag_system.generate_summary(st.session_state.rag_system.document_text, summary_language) |
|
if summary: |
|
st.session_state.rag_system.document_summary = summary |
|
st.text_area("Document Summary", summary, height=200) |
|
st.success("Summary generated successfully!") |
|
else: |
|
st.error("Failed to generate summary.") |
|
else: |
|
st.info("Please process documents first to generate summary.") |
|
|
|
|
|
st.subheader("Step 3: Ask Questions") |
|
st.write("Select Q&A Language:") |
|
qa_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"] |
|
qa_language = st.radio( |
|
"", |
|
qa_language_options, |
|
horizontal=True, |
|
key="qa_language" |
|
) |
|
|
|
if st.session_state.rag_system.qa_chain:
    # Keep the conversation in session state so it survives Streamlit reruns.
    if "qa_history" not in st.session_state:
        st.session_state.qa_history = []
    user_question = st.text_input("Ask a question:")
    if st.button("Submit Question"):
        with st.spinner("Answering your question, please wait..."):
            st.session_state.qa_history = st.session_state.rag_system.handle_query(
                user_question, st.session_state.qa_history, qa_language
            )
    for question, answer in st.session_state.qa_history:
        st.chat_message("user").write(question)
        st.chat_message("assistant").write(answer)
|
else: |
|
st.info("Please process documents first to enable Q&A.") |
|
|
|
|
|
|
|
st.subheader("Step 5: Explore Subtopics via Multi-Agent Graph") |
|
story_topic = st.text_input("Enter main topic:", value="Machine Learning") |
|
story_context = st.text_input("Enter learning context:", value="Education") |
|
|
|
if st.button("Run Story Graph"): |
|
with st.spinner("Generating subtopics and stories..."): |
|
result = st.session_state.rag_system.run_multiagent_storygraph(topic=story_topic, context=story_context) |
|
|
|
subtopics = result.get("sub_topic_list", []) |
|
st.markdown("### π§ Extracted Subtopics") |
|
for sub in subtopics: |
|
st.markdown(f"- {sub}") |
|
|
|
stories = result.get("stories", []) |
|
if stories: |
|
st.markdown("### π Generated Stories") |
|
for i, story in enumerate(stories): |
|
st.markdown(f"**Story {i+1}:**") |
|
st.markdown(story.content) |
|
else: |
|
st.warning("No stories were generated.") |
|
|
|
|
|
st.subheader("Step 4: Generate Podcast") |
|
st.write("Select Podcast Language:") |
|
podcast_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"] |
|
podcast_language = st.radio( |
|
"", |
|
podcast_language_options, |
|
horizontal=True, |
|
key="podcast_language" |
|
) |
|
|
|
|
|
if st.session_state.rag_system.document_summary: |
|
if st.button("Generate Podcast"): |
|
with st.spinner("Generating podcast, please wait..."): |
|
script, audio_path = st.session_state.rag_system.create_podcast(podcast_language) |
|
if audio_path: |
|
st.text_area("Generated Podcast Script", script, height=200) |
|
st.audio(audio_path, format="audio/mp3") |
|
st.success("Podcast generated successfully! You can listen to it above.") |
|
else: |
|
st.error(script) |
|
else: |
|
st.info("Please process documents and generate summary before creating a podcast.") |
|
|