# gpt-oss-RAG / app.py
import gradio as gr
import spaces
import os
from typing import List, Dict, Any, Optional, Tuple
import hashlib
from datetime import datetime
import numpy as np
from transformers import pipeline, TextIteratorStreamer
import torch
from threading import Thread
import re
# PDF processing library
try:
import fitz # PyMuPDF
PDF_AVAILABLE = True
except ImportError:
PDF_AVAILABLE = False
print("⚠️ PyMuPDF not installed. Install with: pip install pymupdf")
try:
from sentence_transformers import SentenceTransformer
ST_AVAILABLE = True
except ImportError:
ST_AVAILABLE = False
print("⚠️ Sentence Transformers not installed. Install with: pip install sentence-transformers")
# Custom CSS
custom_css = """
.gradio-container {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
min-height: 100vh;
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.main-container {
background: rgba(255, 255, 255, 0.98);
border-radius: 16px;
padding: 20px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
border: 1px solid rgba(0, 0, 0, 0.05);
}
.sidebar-container {
background: rgba(255, 255, 255, 0.98);
border-radius: 12px;
padding: 16px;
box-shadow: 0 2px 4px -1px rgba(0, 0, 0, 0.06);
border: 1px solid rgba(0, 0, 0, 0.05);
height: fit-content;
}
.pdf-status {
padding: 10px 14px;
border-radius: 10px;
margin: 8px 0;
font-size: 0.9rem;
font-weight: 500;
}
.pdf-success {
background: linear-gradient(135deg, #d4edda 0%, #c3e6cb 100%);
border: 1px solid #b1dfbb;
color: #155724;
}
.pdf-error {
background: linear-gradient(135deg, #f8d7da 0%, #f5c6cb 100%);
border: 1px solid #f1aeb5;
color: #721c24;
}
.pdf-info {
background: linear-gradient(135deg, #d1ecf1 0%, #bee5eb 100%);
border: 1px solid #9ec5d8;
color: #0c5460;
}
.rag-context {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
border-left: 4px solid #f59e0b;
padding: 10px;
margin: 8px 0;
border-radius: 6px;
font-size: 0.85rem;
}
.status-badge {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.85rem;
font-weight: 600;
margin: 4px 0;
}
.status-enabled {
background: #10b981;
color: white;
}
.status-disabled {
background: #6b7280;
color: white;
}
/* Chat interface maximization */
.chat-container {
height: calc(100vh - 200px) !important;
min-height: 600px;
}
/* Accordion styling */
.accordion {
margin: 8px 0;
}
"""
class SimpleTextSplitter:
"""텍스트 분할기"""
def __init__(self, chunk_size=800, chunk_overlap=100):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
    def split_text(self, text: str) -> List[str]:
        """Split text into sentence-aligned chunks with character overlap."""
        chunks = []
        sentences = text.split('. ')
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) < self.chunk_size:
                current_chunk += sentence + ". "
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                # Seed the next chunk with the tail of the previous one so that
                # chunk_overlap is actually honored
                overlap = current_chunk[-self.chunk_overlap:] if current_chunk else ""
                current_chunk = overlap + sentence + ". "
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks
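# A minimal usage sketch for SimpleTextSplitter (illustrative comment only;
# nothing here runs at import time):
#
#     splitter = SimpleTextSplitter(chunk_size=800, chunk_overlap=100)
#     chunks = splitter.split_text("First sentence. Second sentence. Third one.")
#     # Each chunk ends on a sentence boundary, stays near chunk_size characters,
#     # and repeats the tail of the previous chunk as overlap.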
class PDFRAGSystem:
"""PDF 기반 RAG 시스템"""
def __init__(self):
self.documents = {}
self.document_chunks = {}
self.embeddings_store = {}
self.text_splitter = SimpleTextSplitter(chunk_size=800, chunk_overlap=100)
        # Initialize the embedding model
self.embedder = None
if ST_AVAILABLE:
try:
self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
print("✅ 임베딩 모델 로드 성공")
except Exception as e:
print(f"⚠️ 임베딩 모델 로드 실패: {e}")
def extract_text_from_pdf(self, pdf_path: str) -> Dict[str, Any]:
"""PDF에서 텍스트 추출"""
if not PDF_AVAILABLE:
return {
"metadata": {
"title": "PDF Reader Not Available",
"file_name": os.path.basename(pdf_path),
"pages": 0
},
"full_text": "PDF 처리를 위해 'pip install pymupdf'를 실행해주세요."
}
try:
doc = fitz.open(pdf_path)
text_content = []
metadata = {
"title": doc.metadata.get("title", os.path.basename(pdf_path)),
"pages": len(doc),
"file_name": os.path.basename(pdf_path)
}
for page_num, page in enumerate(doc):
text = page.get_text()
if text.strip():
text_content.append(text)
doc.close()
return {
"metadata": metadata,
"full_text": "\n\n".join(text_content)
}
except Exception as e:
raise Exception(f"PDF 처리 오류: {str(e)}")
def process_and_store_pdf(self, pdf_path: str, doc_id: str) -> Dict[str, Any]:
"""PDF 처리 및 저장"""
try:
            # Extract text from the PDF
pdf_data = self.extract_text_from_pdf(pdf_path)
            # Split the text into chunks
chunks = self.text_splitter.split_text(pdf_data["full_text"])
if not chunks:
print("Warning: No chunks created from PDF")
return {"success": False, "error": "No text content found in PDF"}
print(f"Created {len(chunks)} chunks from PDF")
            # Store the chunks
self.document_chunks[doc_id] = chunks
            # Generate embeddings (optional)
if self.embedder:
try:
print("Generating embeddings...")
embeddings = self.embedder.encode(chunks)
self.embeddings_store[doc_id] = embeddings
print(f"Generated {len(embeddings)} embeddings")
except Exception as e:
print(f"Warning: Failed to generate embeddings: {e}")
                    # Continue without embeddings; search falls back to chunk order
            # Store document metadata
self.documents[doc_id] = {
"metadata": pdf_data["metadata"],
"chunk_count": len(chunks),
"upload_time": datetime.now().isoformat()
}
            # Debug: preview the first chunk
print(f"First chunk preview: {chunks[0][:200]}...")
return {
"success": True,
"doc_id": doc_id,
"chunks": len(chunks),
"pages": pdf_data["metadata"]["pages"],
"title": pdf_data["metadata"]["title"]
}
except Exception as e:
print(f"Error processing PDF: {e}")
return {"success": False, "error": str(e)}
def search_relevant_chunks(self, query: str, doc_ids: List[str], top_k: int = 3) -> List[Dict]:
"""관련 청크 검색"""
all_relevant_chunks = []
print(f"Searching chunks for query: '{query[:50]}...' in {len(doc_ids)} documents")
        # Verify each requested document actually exists
for doc_id in doc_ids:
if doc_id not in self.document_chunks:
print(f"Warning: Document {doc_id} not found in chunks")
continue
chunks = self.document_chunks[doc_id]
print(f"Document {doc_id} has {len(chunks)} chunks")
            # Try embedding-based search first
if self.embedder and doc_id in self.embeddings_store:
try:
query_embedding = self.embedder.encode([query])[0]
doc_embeddings = self.embeddings_store[doc_id]
                    # Compute cosine similarity, guarding against zero-norm vectors
similarities = []
for i, emb in enumerate(doc_embeddings):
try:
query_norm = np.linalg.norm(query_embedding)
emb_norm = np.linalg.norm(emb)
if query_norm > 0 and emb_norm > 0:
sim = np.dot(query_embedding, emb) / (query_norm * emb_norm)
similarities.append(sim)
else:
similarities.append(0.0)
except Exception as e:
print(f"Error calculating similarity for chunk {i}: {e}")
similarities.append(0.0)
                    # Select the top-scoring chunks
if similarities:
top_indices = np.argsort(similarities)[-min(top_k, len(similarities)):][::-1]
for idx in top_indices:
                            if idx < len(chunks):  # bounds check
all_relevant_chunks.append({
"content": chunks[idx],
"doc_name": self.documents[doc_id]["metadata"]["file_name"],
"similarity": similarities[idx]
})
print(f"Added chunk {idx} with similarity: {similarities[idx]:.3f}")
except Exception as e:
print(f"Error in embedding search: {e}")
                    # Fall through to the simple fallback below
            # No embeddings (or embedding search failed): return the first N chunks
if not all_relevant_chunks:
print(f"Falling back to simple chunk selection for {doc_id}")
for i in range(min(top_k, len(chunks))):
all_relevant_chunks.append({
"content": chunks[i],
"doc_name": self.documents[doc_id]["metadata"]["file_name"],
"similarity": 1.0 - (i * 0.1) # 순서대로 가중치
})
print(f"Added chunk {i} (fallback)")
        # Sort by similarity
all_relevant_chunks.sort(key=lambda x: x.get('similarity', 0), reverse=True)
        # Keep the top K overall
result = all_relevant_chunks[:top_k]
print(f"Returning {len(result)} chunks")
        # Debug: preview part of the first chunk
if result:
print(f"First chunk preview: {result[0]['content'][:100]}...")
return result
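    # Note: the per-chunk similarity loop above could be vectorized with NumPy.
    # A sketch, assuming doc_embeddings is a 2-D float array:
    #
    #     norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
    #     sims = (doc_embeddings @ query_embedding) / np.where(norms > 0, norms, 1.0)
    #
    # The explicit loop is kept for per-chunk error handling.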
def create_rag_prompt(self, query: str, doc_ids: List[str], top_k: int = 3) -> tuple:
"""RAG 프롬프트 생성 - 쿼리와 컨텍스트를 분리하여 반환"""
print(f"Creating RAG prompt for query: '{query[:50]}...' with docs: {doc_ids}")
relevant_chunks = self.search_relevant_chunks(query, doc_ids, top_k)
if not relevant_chunks:
print("No relevant chunks found - checking if documents exist")
            # Documents exist but nothing matched; fall back to the first available chunk
for doc_id in doc_ids:
if doc_id in self.document_chunks and self.document_chunks[doc_id]:
print(f"Using first chunk from {doc_id} as fallback")
relevant_chunks = [{
"content": self.document_chunks[doc_id][0],
"doc_name": self.documents[doc_id]["metadata"]["file_name"],
"similarity": 0.5
}]
break
if not relevant_chunks:
print("No documents or chunks available")
return query, ""
print(f"Using {len(relevant_chunks)} chunks for context")
        # Assemble the context block
context_parts = []
context_parts.append("Based on the following document context, please answer the question below:")
context_parts.append("=" * 40)
for i, chunk in enumerate(relevant_chunks, 1):
context_parts.append(f"\n[Document Reference {i} - {chunk['doc_name']}]")
            # Cap each chunk at 1000 characters for the prompt
content = chunk['content'][:1000] if len(chunk['content']) > 1000 else chunk['content']
context_parts.append(content)
print(f"Added chunk {i} ({len(content)} chars) with similarity: {chunk.get('similarity', 0):.3f}")
context_parts.append("\n" + "=" * 40)
context = "\n".join(context_parts)
enhanced_query = f"{context}\n\nQuestion: {query}\n\nAnswer based on the document context provided above:"
print(f"Enhanced query length: {len(enhanced_query)} chars (original: {len(query)} chars)")
return enhanced_query, context
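# End-to-end usage sketch for PDFRAGSystem (illustrative; "report.pdf" and
# "doc_demo" are placeholder names, not part of this app):
#
#     rag = PDFRAGSystem()
#     info = rag.process_and_store_pdf("report.pdf", "doc_demo")
#     if info["success"]:
#         enhanced_query, context = rag.create_rag_prompt(
#             "What does the report conclude?", ["doc_demo"], top_k=3)
#         # enhanced_query is sent to the model; context drives the UI indicator.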
# Initialize model and RAG system
model_id = "openai/gpt-oss-20b"
pipe = pipeline(
"text-generation",
model=model_id,
torch_dtype="auto",
device_map="auto",
)
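# Note: device_map="auto" lets accelerate place the 20B model's weights across
# the available devices, and torch_dtype="auto" keeps the dtype stored in the
# checkpoint config instead of forcing one here.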
rag_system = PDFRAGSystem()
# Global state for RAG
rag_enabled = False
selected_docs = []
top_k_chunks = 3
last_context = ""
def format_conversation_history(chat_history):
"""Format conversation history for the model"""
messages = []
for item in chat_history:
role = item["role"]
content = item["content"]
if isinstance(content, list):
            content = content[0].get("text", str(content)) if content and isinstance(content[0], dict) else str(content)
messages.append({"role": role, "content": content})
return messages
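# The returned list is the standard chat-message format that the
# text-generation pipeline accepts, e.g. (illustrative):
#
#     [{"role": "user", "content": "Hi"},
#      {"role": "assistant", "content": "Hello! How can I help?"}]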
@spaces.GPU()
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
"""Generate response with optional RAG enhancement"""
global last_context, rag_enabled, selected_docs, top_k_chunks
# Debug logging
print(f"RAG Enabled: {rag_enabled}")
print(f"Selected Docs: {selected_docs}")
print(f"Available Docs: {list(rag_system.documents.keys())}")
# Apply RAG if enabled
if rag_enabled and selected_docs:
doc_ids = [doc.split(":")[0] for doc in selected_docs]
enhanced_input, context = rag_system.create_rag_prompt(input_data, doc_ids, top_k_chunks)
last_context = context
actual_input = enhanced_input
print(f"RAG Applied - Original: {len(input_data)} chars, Enhanced: {len(enhanced_input)} chars")
else:
actual_input = input_data
last_context = ""
print("RAG Not Applied")
# Prepare messages
new_message = {"role": "user", "content": actual_input}
system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
processed_history = format_conversation_history(chat_history)
messages = system_message + processed_history + [new_message]
# Setup streaming
streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": True,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"streamer": streamer
}
thread = Thread(target=pipe, args=(messages,), kwargs=generation_kwargs)
thread.start()
# Process streaming output
thinking = ""
final = ""
started_final = False
for chunk in streamer:
        if not started_final:
            if "assistantfinal" in chunk.lower():
                # Split case-insensitively to match the check above and avoid
                # an IndexError when the marker's casing varies
                split_parts = re.split(r'assistantfinal', chunk, maxsplit=1, flags=re.IGNORECASE)
                thinking += split_parts[0]
                final += split_parts[1]
                started_final = True
else:
thinking += chunk
else:
final += chunk
clean_thinking = re.sub(r'^analysis\s*', '', thinking).strip()
clean_final = final.strip()
# Add RAG context indicator if used
rag_indicator = ""
if rag_enabled and selected_docs and last_context:
rag_indicator = "<div class='rag-context'>📚 RAG Context Applied</div>\n\n"
formatted = f"{rag_indicator}<details open><summary>Click to view Thinking Process</summary>\n\n{clean_thinking}\n\n</details>\n\n{clean_final}"
yield formatted
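# Why the "assistantfinal" split above works: gpt-oss checkpoints emit a
# reasoning channel followed by a final-answer channel, and with special
# tokens skipped the raw stream looks roughly like
# "analysis<reasoning...>assistantfinal<answer...>", so the code separates
# the two and strips the leading "analysis" tag.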
def upload_pdf(file):
"""PDF 파일 업로드 처리"""
if file is None:
return (
gr.update(value="<div class='pdf-status pdf-info'>📁 파일을 선택해주세요</div>"),
gr.update(choices=[])
)
try:
        # Use a hash of the file contents as the document ID
with open(file.name, 'rb') as f:
file_hash = hashlib.md5(f.read()).hexdigest()[:8]
doc_id = f"doc_{file_hash}"
        # Process and store the PDF
result = rag_system.process_and_store_pdf(file.name, doc_id)
if result["success"]:
status_html = f"""
<div class="pdf-status pdf-success">
✅ PDF 업로드 완료!<br>
📄 {result['title']}<br>
📑 {result['pages']} 페이지 | 🔍 {result['chunks']} 청크
</div>
"""
            # Refresh the document list
doc_choices = [f"{doc_id}: {rag_system.documents[doc_id]['metadata']['file_name']}"
for doc_id in rag_system.documents.keys()]
return (
status_html,
gr.update(choices=doc_choices, value=doc_choices)
)
else:
return (
f"<div class='pdf-status pdf-error'>❌ 오류: {result['error']}</div>",
gr.update()
)
except Exception as e:
return (
f"<div class='pdf-status pdf-error'>❌ 오류: {str(e)}</div>",
gr.update()
)
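# Each checkbox choice is formatted as "doc_<hash>: <file name>";
# generate_response later splits on ":" to recover the doc_id.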
def clear_documents():
"""문서 초기화"""
global selected_docs
rag_system.documents = {}
rag_system.document_chunks = {}
rag_system.embeddings_store = {}
selected_docs = []
return (
gr.update(value="<div class='pdf-status pdf-info'>🗑️ 모든 문서가 삭제되었습니다</div>"),
gr.update(choices=[], value=[])
)
def update_rag_settings(enable, docs, k):
    """Update RAG settings (unused: superseded by update_rag_status_simple below)."""
global rag_enabled, selected_docs, top_k_chunks
rag_enabled = enable
selected_docs = docs if docs else []
top_k_chunks = k
# Debug logging
print(f"RAG Settings Updated - Enabled: {rag_enabled}, Docs: {selected_docs}, Top-K: {top_k_chunks}")
status = "✅ Enabled" if enable and docs else "⭕ Disabled"
status_html = f"<div class='pdf-status pdf-info'>🔍 RAG: <strong>{status}</strong></div>"
# Show context preview if RAG is enabled
if enable and docs:
preview = f"<div class='rag-context'>📚 Using {len(docs)} document(s) with {k} chunks per query</div>"
return gr.update(value=status_html), gr.update(value=preview, visible=True)
else:
return gr.update(value=status_html), gr.update(value="", visible=False)
# Build the interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, fill_height=True) as demo:
with gr.Row():
# Compact sidebar
with gr.Column(scale=1, min_width=300):
gr.Markdown("## 🚀 GPT-OSS-20B + RAG")
# RAG Status Badge
with gr.Group(elem_classes="sidebar-container"):
rag_status = gr.HTML(
value="<div class='status-badge status-disabled'>RAG: Disabled</div>"
)
context_preview = gr.HTML(value="", visible=False)
# PDF Upload Section
with gr.Accordion("📄 PDF Documents", open=True, elem_classes="accordion"):
pdf_upload = gr.File(
label="Upload PDF",
file_types=[".pdf"],
type="filepath",
elem_classes="compact-upload"
)
upload_status = gr.HTML(
value="<div style='font-size: 0.85rem; color: #6b7280;'>No documents uploaded</div>"
)
document_list = gr.CheckboxGroup(
choices=[],
label="Select Documents",
elem_classes="compact-checkbox"
)
with gr.Row():
enable_rag = gr.Checkbox(
label="Enable RAG",
value=False,
scale=2
)
clear_btn = gr.Button("Clear", size="sm", variant="secondary", scale=1)
# RAG Settings
with gr.Accordion("⚙️ RAG Settings", open=False, elem_classes="accordion"):
top_k_slider = gr.Slider(
minimum=1,
maximum=5,
value=3,
step=1,
label="Context Chunks",
info="Number of chunks to use"
)
# Model Settings
with gr.Accordion("🔧 Model Settings", open=False, elem_classes="accordion"):
max_tokens = gr.Slider(
label="Max tokens",
minimum=64,
maximum=4096,
step=1,
value=2048
)
temperature = gr.Slider(
label="Temperature",
minimum=0.1,
maximum=2.0,
step=0.1,
value=0.7
)
with gr.Row():
top_p = gr.Slider(
label="Top-p",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9
)
top_k = gr.Slider(
label="Top-k",
minimum=1,
maximum=100,
step=1,
value=50
)
repetition_penalty = gr.Slider(
label="Repetition Penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.0
)
# System Prompt
with gr.Accordion("💬 System Prompt", open=False, elem_classes="accordion"):
system_prompt = gr.Textbox(
label="System Prompt",
value="You are a helpful assistant. Reasoning: medium",
lines=3,
placeholder="Customize the system prompt..."
)
# Main chat area - maximized
with gr.Column(scale=4):
with gr.Group(elem_classes="main-container chat-container"):
# Create ChatInterface with custom function
chat_interface = gr.ChatInterface(
fn=generate_response,
additional_inputs=[
max_tokens,
system_prompt,
temperature,
top_p,
top_k,
repetition_penalty
],
                    examples=[
                        ["Summarize the document"],
                        ["What are the key points mentioned?"],
                        ["Explain the main concept"],
                    ],
cache_examples=False,
type="messages",
title=None,
description=None,
textbox=gr.Textbox(
placeholder="Ask anything... (RAG will be applied if enabled)",
container=False,
scale=7
),
chatbot=gr.Chatbot(
height=550,
show_copy_button=True,
render_markdown=True,
type="messages"
),
submit_btn="Send",
stop_btn="Stop",
multimodal=False
)
# Event handlers
pdf_upload.upload(
fn=upload_pdf,
inputs=[pdf_upload],
outputs=[upload_status, document_list]
)
clear_btn.click(
fn=clear_documents,
outputs=[upload_status, document_list]
)
# Simplified RAG status update
def update_rag_status_simple(enable, docs, k):
"""Simplified RAG status update"""
global rag_enabled, selected_docs, top_k_chunks
rag_enabled = enable
selected_docs = docs if docs else []
top_k_chunks = k
if enable and docs:
status_html = "<div class='status-badge status-enabled'>RAG: Active</div>"
preview = f"<div style='font-size: 0.85rem; color: #10b981;'>📚 {len(docs)} doc(s) | {k} chunks</div>"
return gr.update(value=status_html), gr.update(value=preview, visible=True)
else:
status_html = "<div class='status-badge status-disabled'>RAG: Disabled</div>"
return gr.update(value=status_html), gr.update(value="", visible=False)
# Update RAG settings when changed
enable_rag.change(
fn=update_rag_status_simple,
inputs=[enable_rag, document_list, top_k_slider],
outputs=[rag_status, context_preview]
)
document_list.change(
fn=update_rag_status_simple,
inputs=[enable_rag, document_list, top_k_slider],
outputs=[rag_status, context_preview]
)
top_k_slider.change(
fn=update_rag_status_simple,
inputs=[enable_rag, document_list, top_k_slider],
outputs=[rag_status, context_preview]
)
if __name__ == "__main__":
demo.launch(share=True)