# app.py

import os
import faiss
import numpy as np
import PyPDF2
import requests
import streamlit as st
from groq import Groq

# Constants
PDF_URL = "https://drive.google.com/uc?export=download&id=1YWX-RYxgtcKO1QETnz1N3rboZUhRZwcH"
VECTOR_DIM = 768   # must match the output dimension of the embedding model in use
CHUNK_SIZE = 512   # chunk length in words

# Function to download the PDF and extract its text
def extract_text_from_pdf(url):
    response = requests.get(url)
    response.raise_for_status()  # fail early on a bad download
    with open("document.pdf", "wb") as f:
        f.write(response.content)

    with open("document.pdf", "rb") as f:
        reader = PyPDF2.PdfReader(f)
        # extract_text() can return None for image-only pages, so fall back to ""
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
    return text
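
# --- Hedged note: Google Drive downloads ------------------------------------
# For files above Drive's size threshold, the uc?export=download URL can return
# an HTML confirmation page instead of the PDF bytes. If that happens, the
# gdown package handles the confirmation step (pip install gdown):
#
#   import gdown
#   gdown.download(PDF_URL, "document.pdf", quiet=False)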

# Function to split text into fixed-size, non-overlapping word chunks
def create_chunks(text, chunk_size):
    words = text.split()
    chunks = [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
    return chunks

# Function to create a FAISS vector store
def create_faiss_index(chunks, vector_dim):
    # Use the GPU when one is available; faiss-cpu builds may not expose
    # get_num_gpus, so guard the lookup to avoid an AttributeError
    if getattr(faiss, "get_num_gpus", lambda: 0)() > 0:
        st.write("Using GPU for FAISS indexing.")
        resource = faiss.StandardGpuResources()  # initialize GPU resources
        index_flat = faiss.IndexFlatL2(vector_dim)
        index = faiss.index_cpu_to_gpu(resource, 0, index_flat)
    else:
        st.write("Using CPU for FAISS indexing.")
        index = faiss.IndexFlatL2(vector_dim)

    # Placeholder: random vectors stand in for real embeddings; see the
    # embed_chunks sketch below for one way to compute real ones
    embeddings = np.random.rand(len(chunks), vector_dim).astype("float32")
    index.add(embeddings)
    return index, embeddings
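
# --- Hedged sketch: real embeddings -----------------------------------------
# The index above is filled with random vectors, so similarity search over it
# is meaningless. One way to get real 768-dimensional embeddings is
# sentence-transformers ("all-mpnet-base-v2" outputs 768-dim vectors); the
# model choice is an assumption, not something this app specifies.
# Requires: pip install sentence-transformers
def embed_chunks(chunks):
    from sentence_transformers import SentenceTransformer  # optional dependency
    model = SentenceTransformer("all-mpnet-base-v2")  # 768-dim, matches VECTOR_DIM
    return model.encode(chunks, convert_to_numpy=True).astype("float32")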

# Initialize the Groq API client (expects the GROQ_API_KEY environment variable)
def get_groq_client():
    return Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Query the Groq model directly (the question is sent without document context)
def query_model(client, question):
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": question}],
        model="llama-3.3-70b-versatile",
    )
    return chat_completion.choices[0].message.content
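
# --- Hedged sketch: retrieval-augmented query -------------------------------
# query_model above never consults the FAISS index, so answers ignore the
# document. A minimal RAG variant (assuming the index was built from real
# embeddings via embed_chunks) embeds the question, retrieves the top-k
# nearest chunks, and prepends them to the prompt as context.
def query_with_context(client, question, index, chunks, k=3):
    query_vec = embed_chunks([question])                   # shape (1, VECTOR_DIM)
    _, ids = index.search(query_vec, min(k, len(chunks)))  # nearest chunk indices
    context = "\n\n".join(chunks[i] for i in ids[0])
    prompt = f"Answer using only this context:\n{context}\n\nQuestion: {question}"
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )
    return chat_completion.choices[0].message.content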

# Streamlit app
def main():
    st.title("RAG-Based ChatBot")

    # Step 1: Extract text from the document
    st.header("Step 1: Extract Text")
    if st.button("Extract Text from PDF"):
        text = extract_text_from_pdf(PDF_URL)
        st.session_state["text"] = text
        st.success("Text extracted successfully!")

    # Step 2: Chunk the text
    st.header("Step 2: Create Chunks")
    if "text" in st.session_state and st.button("Create Chunks"):
        chunks = create_chunks(st.session_state["text"], CHUNK_SIZE)
        st.session_state["chunks"] = chunks
        st.success(f"Created {len(chunks)} chunks.")

    # Step 3: Create FAISS index
    st.header("Step 3: Create Vector Database")
    if "chunks" in st.session_state and st.button("Create Vector Database"):
        index, _ = create_faiss_index(st.session_state["chunks"], VECTOR_DIM)
        st.session_state["index"] = index
        st.success("FAISS vector database created.")

    # Step 4: Ask a question
    st.header("Step 4: Query the Model")
    question = st.text_input("Ask a question about the document:")
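    # Note: query_model sends the question to the LLM without document context.
    # For real retrieval-augmented answers, swap in query_with_context (sketch
    # above), which assumes the index holds real embeddings from embed_chunks.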
    if question and "index" in st.session_state:
        client = get_groq_client()
        answer = query_model(client, question)
        st.write("Answer:", answer)

if __name__ == "__main__":
    main()