import os
import gc
import tempfile
import uuid
import logging

import streamlit as st
from dotenv import load_dotenv
from gitingest import ingest
from llama_index.core import Settings, PromptTemplate, VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.llms.sambanovasystems import SambaNovaCloud
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Load environment variables from .env
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Custom exception for application errors
class GitHubRAGError(Exception):
    """Custom exception for GitHub RAG application errors."""
    pass

# Fetch API key for SambaNova
SAMBANOVA_API_KEY = os.getenv("SAMBANOVA_API_KEY")
if not SAMBANOVA_API_KEY:
    raise ValueError("SAMBANOVA_API_KEY is not set in environment variables")
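# The key can be supplied via a .env file next to this script, e.g.:
#   SAMBANOVA_API_KEY=your-key-here
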
# Initialize Streamlit session state
if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}
    st.session_state.messages = []

session_id = st.session_state.id

@st.cache_resource  # cache the client across Streamlit reruns, as the docstring promises
def load_llm():
    """
    Load and cache the SambaNova LLM predictor.
    """
    return SambaNovaCloud(
        api_key=SAMBANOVA_API_KEY,
        model="DeepSeek-R1-Distill-Llama-70B",
        temperature=0.1,
        top_p=0.1,
    )
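# The low temperature and top_p above bias the model toward deterministic,
# grounded answers; raise them for more exploratory output.
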
def reset_chat():
    """Clear chat history and free resources."""
    st.session_state.messages = []
    gc.collect()

def process_with_gitingest(github_url: str):
    """Use gitingest to fetch and summarize the GitHub repository."""
    summary, tree, content = ingest(github_url)
    return summary, tree, content
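# ingest() returns (summary, tree, content): a short textual summary, the
# repository's directory tree as a string, and all file contents flattened
# into a single Markdown document.
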
# --- Sidebar: Load Repository ---
with st.sidebar:
    st.header("Add your GitHub repository!")
    github_url = st.text_input(
        "GitHub repo URL", placeholder="https://github.com/user/repo"
    )
    load_btn = st.button("Load Repository")

    if github_url and load_btn:
        try:
            repo_name = github_url.rstrip("/").split("/")[-1]
            cache_key = f"{session_id}-{repo_name}"

            # Only process if not cached
            if cache_key not in st.session_state.file_cache:
                with st.spinner("Processing repository..."):
                    summary, tree, content = process_with_gitingest(github_url)
                    with tempfile.TemporaryDirectory() as tmpdir:
                        md_path = os.path.join(tmpdir, f"{repo_name}.md")
                        with open(md_path, "w", encoding="utf-8") as f:
                            f.write(content)
                        loader = SimpleDirectoryReader(input_dir=tmpdir)
                        docs = loader.load_data()

                        embed_model = HuggingFaceEmbedding(
                            model_name="nomic-ai/nomic-embed-text-v2-moe",
                            trust_remote_code=True,
                        )
                        Settings.embed_model = embed_model
                        llm_predictor = load_llm()
                        Settings.llm = llm_predictor
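
                        # MarkdownNodeParser splits the Markdown dump on its
                        # headings, so each indexed node maps roughly to one
                        # file or section of the repository.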
                        node_parser = MarkdownNodeParser()
                        index = VectorStoreIndex.from_documents(
                            documents=docs,
                            transformations=[node_parser],
                            show_progress=True,
                        )

                    qa_prompt = PromptTemplate(
                        "You are an AI assistant specialized in analyzing GitHub repositories.\n"
                        "Repository structure:\n{tree}\n---\n"
                        "Context:\n{context_str}\n---\n"
                        "Question: {query_str}\nAnswer:"
                    )
                    # Pre-fill {tree}: the query engine only supplies
                    # context_str and query_str at query time.
                    qa_prompt = qa_prompt.partial_format(tree=tree)

                    query_engine = index.as_query_engine(streaming=True)
                    # Override the prompt slot used when synthesizing answers.
                    query_engine.update_prompts({
                        "response_synthesizer:text_qa_template": qa_prompt
                    })
                st.session_state.file_cache[cache_key] = (query_engine, tree)
                st.success("Repository loaded and indexed. Ready to chat!")
            else:
                st.info("Repository already loaded.")
        except Exception as e:
            st.error(f"Error loading repository: {e}")
            logger.error(f"Load error: {e}")

# --- Main UI: Chat Interface ---
col1, col2 = st.columns([6, 1])
with col1:
    st.header("Chat with GitHub RAG")
with col2:
    st.button("Clear Chat ↺", on_click=reset_chat)

# Display chat history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# Chat input box
if prompt := st.chat_input("Ask a question about the repository..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    repo_name = github_url.rstrip("/").split("/")[-1]
    cache_key = f"{session_id}-{repo_name}"

    if cache_key not in st.session_state.file_cache:
        st.error("Please load a repository first!")
    else:
        query_engine, tree = st.session_state.file_cache[cache_key]
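        # Stream the answer into a placeholder so it renders incrementally.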
        with st.chat_message("assistant"):
            placeholder = st.empty()
            response_text = ""
            try:
                response = query_engine.query(prompt)
                # Streaming query engines expose a response_gen iterator.
                if hasattr(response, "response_gen"):
                    for chunk in response.response_gen:
                        response_text += chunk
                        placeholder.markdown(response_text + "▌")
                else:
                    response_text = str(response)
            except GitHubRAGError as e:
                st.error(str(e))
                logger.error(f"Error in chat processing: {e}")
                response_text = "Sorry, I couldn't process that request."
            except Exception as e:
                st.error("An unexpected error occurred while processing your query")
                logger.error(f"Unexpected error in chat: {e}")
                response_text = "Sorry, something went wrong."
            # Final render without the streaming cursor.
            placeholder.markdown(response_text)

        st.session_state.messages.append({"role": "assistant", "content": response_text})
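
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py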