Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -2,12 +2,7 @@ import os
 import tempfile
 import torch
 import gradio as gr
-
-# Patch torch if needed
-if not hasattr(torch, "get_default_device"):
-    def get_default_device():
-        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    torch.get_default_device = get_default_device
+import spaces  # Required for GPU-enabled Spaces
 
 from pinecone import Pinecone
 from langchain_pinecone import PineconeVectorStore
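For context on the new import: on ZeroGPU Spaces the `spaces` package provides the `spaces.GPU` decorator used later in this diff, and CUDA is only available inside functions it decorates. A minimal sketch of the pattern, where the `device_name` helper is illustrative and not part of this commit:

import spaces
import torch

@spaces.GPU  # a GPU is attached for the duration of each decorated call
def device_name() -> str:
    # torch.cuda only becomes usable on ZeroGPU hardware inside a decorated call
    return torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"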
@@ -17,21 +12,25 @@ from langchain.chains import RetrievalQAWithSourcesChain
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain.embeddings import HuggingFaceEmbeddings
 
-# Set device
+# Set device for embeddings
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Initialize Pinecone
 pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
 INDEX_NAME = "ragreader"
 
-#
-embeddings = HuggingFaceEmbeddings(
-    model_name="BAAI/bge-large-en-v1.5",
-    model_kwargs={"device": device}
-)
+# GPU-decorated function to load HuggingFace embeddings on GPU
+@spaces.GPU
+def init_embeddings():
+    return HuggingFaceEmbeddings(
+        model_name="BAAI/bge-large-en-v1.5",
+        model_kwargs={"device": device}
+    )
 
+embeddings = init_embeddings()
 
-#
+# GPU-decorated document processing function
+@spaces.GPU
 def process_documents(files):
     docs = []
     for file in files:
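The middle of process_documents (old lines 38-53) falls between hunks and is not shown. A plausible sketch of the elided body, inferred from the imports and the surrounding context lines; the PyPDFLoader choice, the chunk sizes, and the exact PineconeVectorStore.from_documents call are assumptions rather than code from this commit, and `embeddings` / `INDEX_NAME` refer to the module-level names defined above:

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_pinecone import PineconeVectorStore

def process_documents(files):
    docs = []
    for file in files:
        # Gradio uploads expose a temp-file path via .name (assumed)
        loader = PyPDFLoader(file.name)
        docs.extend(loader.load())

    # Split into overlapping chunks for retrieval (sizes assumed)
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(docs)

    # Embed the chunks and upsert them into the existing Pinecone index
    PineconeVectorStore.from_documents(
        chunks,
        embedding=embeddings,
        index_name=INDEX_NAME,
    )
    return "Documents processed and stored."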
@@ -54,7 +53,7 @@ def process_documents(files):
     )
     return "Documents processed and stored."
 
-# Initialize
+# Initialize the RetrievalQA chain (no GPU decoration needed here)
 def init_qa_chain():
     llm = ChatMistralAI(
         model="mistral-tiny",
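Only the opening lines of init_qa_chain appear in the diff. Given the RetrievalQAWithSourcesChain import visible in the hunk header above, a sketch of how the chain is plausibly assembled; the temperature, the env-var name, and the retriever settings are assumptions, with `embeddings` and `INDEX_NAME` again taken from module scope:

import os
from langchain_mistralai import ChatMistralAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_pinecone import PineconeVectorStore

def init_qa_chain():
    llm = ChatMistralAI(
        model="mistral-tiny",
        temperature=0,                         # assumed
        api_key=os.getenv("MISTRAL_API_KEY"),  # assumed env-var name
    )
    # Reconnect to the index that process_documents populated
    vectorstore = PineconeVectorStore(index_name=INDEX_NAME, embedding=embeddings)
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
    )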
@@ -76,7 +75,7 @@ def init_qa_chain():
 
 qa_chain = None
 
-# Gradio
+# Build the Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("## RAG Chatbot - PDF Reader")
 
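The diff shows only the opening of the Gradio block. Since qa_chain starts as None, the app presumably builds the chain lazily on the first question; a sketch of plausible wiring on top of the functions above, where every component name, label, and the ask handler are assumptions:

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("## RAG Chatbot - PDF Reader")

    # Upload PDFs and index them with the GPU-decorated processor
    files = gr.File(file_count="multiple", file_types=[".pdf"])
    status = gr.Textbox(label="Status", interactive=False)
    gr.Button("Process documents").click(process_documents, inputs=files, outputs=status)

    # Ask questions against the indexed documents
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Answer", interactive=False)

    def ask(q):
        global qa_chain
        if qa_chain is None:  # lazy init, matching qa_chain = None above
            qa_chain = init_qa_chain()
        result = qa_chain({"question": q})
        return result["answer"]

    gr.Button("Ask").click(ask, inputs=question, outputs=answer)

demo.launch()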