# DeepScaleR1 / app.py
import gradio as gr
import os
import time
from typing import List, Tuple, Optional
from pathlib import Path
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.docstore.document import Document
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
BitsAndBytesConfig,
StoppingCriteria,
StoppingCriteriaList,
)
import torch
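
# Configuration: embedding model, chat model, target device, and context budget.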
EMBEDDING_MODEL = "BAAI/bge-m3"
MODEL_NAME = "agentica-org/DeepScaleR-1.5B-Preview"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_CONTEXT_LENGTH = 8192
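
# Use 4-bit NF4 quantization on CUDA to reduce memory; fall back to full precision on CPU.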
bnb_config = (
BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
if DEVICE == "cuda"
else None
)
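
# Custom stopping criterion: generation halts when the last emitted token matches a
# stop id (token id 0 is assumed to serve as the stop marker for this tokenizer).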
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
stop_ids = [0]
return input_ids[0][-1] in stop_ids
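
# Keep only paths that exist on disk and end in .pdf or .txt.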
def validate_file_paths(file_paths: List[str]) -> List[str]:
valid_paths = []
for path in file_paths:
try:
if Path(path).exists() and Path(path).suffix.lower() in [".pdf", ".txt"]:
valid_paths.append(path)
except (OSError, PermissionError) as e:
print(f"File validation error: {str(e)}")
return valid_paths
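
# Load the uploaded PDF/TXT files and split them into overlapping chunks for retrieval.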
def load_documents(file_paths: List[str]) -> List[Document]:
documents = []
valid_paths = validate_file_paths(file_paths)
if not valid_paths:
raise ValueError("No valid PDF/TXT files found!")
for path in valid_paths:
try:
if path.endswith(".pdf"):
loader = PyPDFLoader(path)
elif path.endswith(".txt"):
loader = TextLoader(path)
docs = loader.load()
if docs:
documents.extend(docs)
except Exception as e:
print(f"Error loading {Path(path).name}: {str(e)}")
if not documents:
raise ValueError("All documents failed to load.")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1024,
chunk_overlap=128,
length_function=len,
add_start_index=True,
separators=["\n\n", "\n", "。", " ", ""],
)
return text_splitter.split_documents(documents)
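
# Embed the chunks with the BGE-M3 model and index them in an in-memory FAISS store.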
def create_vector_store(documents: List[Document]) -> FAISS:
if not documents:
raise ValueError("No documents to index.")
embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL,
model_kwargs={"device": DEVICE},
encode_kwargs={"normalize_embeddings": True},
)
return FAISS.from_documents(documents, embeddings)
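
# Build the full RAG chain: tokenizer and (optionally quantized) model, a
# text-generation pipeline, conversation memory, and an MMR retriever over FAISS.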
def initialize_deepseek_model(
vector_store: FAISS,
temperature: float = 0.7,
max_new_tokens: int = 1024,
top_k: int = 50,
) -> ConversationalRetrievalChain:
try:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME, use_fast=True, trust_remote_code=True
)
torch_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto" if DEVICE == "cuda" else None,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
text_pipeline = pipeline(
"text-generation",
model=model,
            tokenizer=tokenizer,
            # Sampling must be enabled for the temperature and top_k settings to take effect.
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            repetition_penalty=1.1,
stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
batch_size=1,
return_full_text=False,
)
llm = HuggingFacePipeline(
pipeline=text_pipeline, model_kwargs={"temperature": temperature}
)
memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True,
output_key="answer",
input_key="question",
)
return ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=vector_store.as_retriever(
search_type="mmr", search_kwargs={"k": 5, "fetch_k": 10}
),
memory=memory,
chain_type="stuff",
return_source_documents=True,
verbose=False,
max_tokens_limit=MAX_CONTEXT_LENGTH,
)
except Exception as e:
raise RuntimeError(f"Model initialization failed: {str(e)}")
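
# Reduce the retrieved documents to exactly three (snippet, page) pairs for the UI.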
def format_sources(source_docs: List[Document]) -> List[Tuple[str, int]]:
sources = []
try:
for doc in source_docs[:3]:
content = doc.page_content.strip()[:500] + "..."
page = doc.metadata.get("page", 0) + 1
sources.append((content, page))
while len(sources) < 3:
sources.append(("No source found", 0))
except Exception:
return [("Source processing error", 0)] * 3
return sources
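
# Run one chat turn. Returns values for all nine Gradio outputs: the chain state,
# the cleared input box, the updated history, and three flattened (snippet, page) pairs.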
def handle_conversation(
qa_chain: Optional[ConversationalRetrievalChain],
message: str,
history: List[Tuple[str, str]],
) -> Tuple:
start_time = time.time()
if not qa_chain:
        # Chain not initialized yet: keep the history and fill the three source slots
        # with placeholder (text, page) pairs so all nine outputs are returned.
        return None, "", history, *(("System Error", 0) * 3)
try:
response = qa_chain.invoke({"question": message, "chat_history": history})
answer = response["answer"].strip()
sources = format_sources(response.get("source_documents", []))
new_history = history + [(message, answer)]
elapsed = f"{(time.time() - start_time):.2f}s"
print(f"Response generated in {elapsed}")
return (
qa_chain,
"",
new_history,
*[item for sublist in sources for item in sublist],
)
except Exception as e:
error_msg = f"⚠️ Error: {str(e)}"
        return qa_chain, "", history + [(message, error_msg)], *(("Error", 0) * 3)
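
# Assemble the Gradio Blocks UI: document upload, model configuration, chat, and sources.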
def create_interface() -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Default()) as interface:
qa_chain = gr.State()
vector_store = gr.State()
gr.Markdown(
"""
            <h1 style="text-align:center; color: #00ffff;">
DeepScale R1
</h1>
<p style="text-align:center; color: #008080;">
            A Safe and Strong Local RAG System by Adarsh Pandey!
</p>
""",
elem_id="header-section",
)
with gr.Row():
with gr.Column(scale=1, min_width=300):
gr.Markdown("### Step 1: Document Processing")
file_input = gr.Files(
file_types=[".pdf", ".txt"], file_count="multiple"
)
process_btn = gr.Button("Process Documents", variant="primary")
process_status = gr.Textbox(label="Status", interactive=False)
gr.Markdown("### Step 2: Model Configuration")
with gr.Accordion("Advanced Parameters", open=False):
temp_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature",
)
token_slider = gr.Slider(
minimum=256,
maximum=4096,
value=1024,
step=128,
label="Response Length",
)
topk_slider = gr.Slider(
minimum=1, maximum=100, value=50, step=5, label="Top-K Sampling"
)
init_btn = gr.Button("Initialize Model", variant="primary")
model_status = gr.Textbox(label="Model Status", interactive=False)
with gr.Column(scale=1, min_width=500):
chatbot = gr.Chatbot(
label="Conversation History",
height=450,
avatar_images=["2.png", "3.png"],
)
msg_input = gr.Textbox(
label="Your Query",
placeholder="Ask a question about your documents...",
)
with gr.Row():
submit_btn = gr.Button("Submit", variant="primary")
clear_btn = gr.ClearButton([msg_input, chatbot], value="Clear Chat")
                # Keep references to the source widgets so the chat handlers can
                # write the retrieved snippets and page numbers into them.
                source_refs = []
                with gr.Accordion("Source References", open=True):
                    for i in range(3):
                        with gr.Row():
                            ref_text = gr.Textbox(
                                label=f"Reference {i+1}", max_lines=4, interactive=False
                            )
                            ref_page = gr.Number(label="Page", value=0, interactive=False)
                            source_refs.extend([ref_text, ref_page])
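
        # Wire the two-step setup: build the FAISS store, then initialize the chain.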
process_btn.click(
fn=lambda files: (
create_vector_store(load_documents([f.name for f in files])),
"Documents processed successfully.",
),
inputs=file_input,
outputs=[vector_store, process_status],
api_name="process_docs",
)
init_btn.click(
fn=lambda vs, temp, tokens, k: (
initialize_deepseek_model(vs, temp, tokens, k),
"Model initialized successfully.",
),
inputs=[vector_store, temp_slider, token_slider, topk_slider],
outputs=[qa_chain, model_status],
api_name="init_model",
)
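
        # Pressing Enter and clicking Submit both run the same conversation handler.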
msg_input.submit(
fn=handle_conversation,
inputs=[qa_chain, msg_input, chatbot],
            outputs=[qa_chain, msg_input, chatbot, *source_refs],
api_name="chat",
)
submit_btn.click(
fn=handle_conversation,
inputs=[qa_chain, msg_input, chatbot],
            outputs=[qa_chain, msg_input, chatbot, *source_refs],
api_name="chat",
)
return interface
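
# Launch the app; bind to all interfaces when running inside Docker.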
if __name__ == "__main__":
app = create_interface()
app.launch(
server_name="0.0.0.0" if os.getenv("DOCKER") else "localhost",
server_port=7860,
show_error=True,
share=True,
favicon_path="1.png",
)