gauri-sharan committed on
Commit
f4736cc
·
verified ·
1 Parent(s): 8cbe701

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -2,12 +2,7 @@ import os
2
  import tempfile
3
  import torch
4
  import gradio as gr
5
-
6
- # Patch torch if needed
7
- if not hasattr(torch, "get_default_device"):
8
- def get_default_device():
9
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- torch.get_default_device = get_default_device
11
 
12
  from pinecone import Pinecone
13
  from langchain_pinecone import PineconeVectorStore
@@ -17,21 +12,25 @@ from langchain.chains import RetrievalQAWithSourcesChain
17
  from langchain_text_splitters import RecursiveCharacterTextSplitter
18
  from langchain.embeddings import HuggingFaceEmbeddings
19
 
20
- # Set device
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
  # Initialize Pinecone
24
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
25
  INDEX_NAME = "ragreader"
26
 
27
- # Embedding initialization
28
- embeddings = HuggingFaceEmbeddings(
29
- model_name="BAAI/bge-large-en-v1.5",
30
- model_kwargs={'device': device}
31
- )
 
 
32
 
 
33
 
34
- # Document processing pipeline
 
35
  def process_documents(files):
36
  docs = []
37
  for file in files:
@@ -54,7 +53,7 @@ def process_documents(files):
54
  )
55
  return "Documents processed and stored."
56
 
57
- # Initialize QA chain
58
  def init_qa_chain():
59
  llm = ChatMistralAI(
60
  model="mistral-tiny",
@@ -76,7 +75,7 @@ def init_qa_chain():
76
 
77
  qa_chain = None
78
 
79
- # Gradio interface
80
  with gr.Blocks() as demo:
81
  gr.Markdown("## RAG Chatbot - PDF Reader")
82
 
 
2
  import tempfile
3
  import torch
4
  import gradio as gr
5
+ import spaces # Required for GPU-enabled Spaces
 
 
 
 
 
6
 
7
  from pinecone import Pinecone
8
  from langchain_pinecone import PineconeVectorStore
 
12
  from langchain_text_splitters import RecursiveCharacterTextSplitter
13
  from langchain.embeddings import HuggingFaceEmbeddings
14
 
15
+ # Set device for embeddings
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # Initialize Pinecone
19
  pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
20
  INDEX_NAME = "ragreader"
21
 
22
+ # GPU-decorated function to load HuggingFace embeddings on GPU
23
+ @spaces.GPU
24
+ def init_embeddings():
25
+ return HuggingFaceEmbeddings(
26
+ model_name="BAAI/bge-large-en-v1.5",
27
+ model_kwargs={"device": device}
28
+ )
29
 
30
+ embeddings = init_embeddings()
31
 
32
+ # GPU-decorated document processing function
33
+ @spaces.GPU
34
  def process_documents(files):
35
  docs = []
36
  for file in files:
 
53
  )
54
  return "Documents processed and stored."
55
 
56
+ # Initialize the RetrievalQA chain (no GPU decoration needed here)
57
  def init_qa_chain():
58
  llm = ChatMistralAI(
59
  model="mistral-tiny",
 
75
 
76
  qa_chain = None
77
 
78
+ # Build the Gradio UI
79
  with gr.Blocks() as demo:
80
  gr.Markdown("## RAG Chatbot - PDF Reader")
81